From 75ec504c85e22792e1b72e4195fb324865a32dbd Mon Sep 17 00:00:00 2001
From: R0CKSTAR
Date: Tue, 6 Jan 2026 19:29:11 +0800
Subject: [PATCH 01/43] Support MThreads (MUSA) GPU (#1162)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds support for the Moore Threads (MUSA) GPU platform, expanding LightLLM's hardware compatibility.

*NOTE:*
1. `_fwd_kernel_token_att1` has been slightly updated to ensure compatibility with the Triton version.
2. `has_mtlink` will be used in upcoming enhancements to enable multi-GPU support.
3. `torch` / `torch_musa` need to be upgraded to the latest versions.

### Testing Done

```bash
root@worker3218:/ws# python -m lightllm.server.api_server --model_dir /home/dist/Qwen3-0.6B/ --disable_cudagraph --host 0.0.0.0
WARNING 01-02 12:22:47 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it.
WARNING 01-02 12:22:47 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
INFO 01-02 12:22:48 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:22:48 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:22:48 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:22:48 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:22:48 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`.
INFO 01-02 12:22:48 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
INFO 01-02 12:22:48 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
WARNING 01-02 12:22:48 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
WARNING 01-02 12:22:48 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!!
INFO 01-02 12:22:48 [shm_size_check.py:21] SHM check: Available=500.00 GB,Recommended=2.32 GB.Sufficient: True INFO 01-02 12:22:48 [api_start.py:94] zmq mode head: ipc:///tmp/_28765_0_ INFO 01-02 12:22:48 [api_start.py:96] use tgi api: False INFO 01-02 12:22:48 [api_start.py:233] alloced ports: [10105, 10128, 10009, 10002, 10268, 10173, 10255, 10190, 10225, 10305] INFO 01-02 12:22:48 [api_start.py:284] all start args:Namespace(run_mode='normal', host='0.0.0.0', port=8000, httpserver_workers=1, zmq_mode='ipc:///tmp/_28765_0_', pd_master_ip='0.0.0.0', pd_master_port=1212, pd_decode_rpyc_port=42000, select_p_d_node_strategy='round_robin', config_server_host=None, config_server_port=None, nixl_pd_kv_page_num=16, nixl_pd_kv_page_size=1024, model_name='default_model_name', model_dir='/home/dist/Qwen3-0.6B/', tokenizer_mode='fast', load_way='HF', max_total_token_num=None, mem_fraction=0.9, batch_max_tokens=8448, eos_id=[151645], tool_call_parser=None, reasoning_parser=None, chat_template=None, running_max_req_size=1000, nnodes=1, node_rank=0, multinode_httpmanager_port=12345, multinode_router_gloo_port=20001, tp=1, dp=1, dp_balancer='bs_balancer', max_req_total_len=16384, nccl_host='127.0.0.1', nccl_port=28765, use_config_server_to_init_nccl=False, mode=[], trust_remote_code=False, disable_log_stats=False, log_stats_interval=10, disable_shm_warning=False, router_token_ratio=0.0, router_max_new_token_len=1024, router_max_wait_tokens=1, disable_aggressive_schedule=False, use_dynamic_prompt_cache=False, disable_dynamic_prompt_cache=False, chunked_prefill_size=4096, disable_chunked_prefill=False, diverse_mode=False, token_healing_mode=False, output_constraint_mode='none', first_token_constraint_mode=False, enable_multimodal=False, enable_multimodal_audio=False, enable_mps=False, disable_custom_allreduce=False, enable_custom_allgather=False, enable_tpsp_mix_mode=False, enable_dp_prefill_balance=False, enable_prefill_microbatch_overlap=False, enable_decode_microbatch_overlap=False, enable_flashinfer_prefill=False, enable_flashinfer_decode=False, enable_fa3=False, cache_capacity=200, embed_cache_storage_size=4, data_type='bfloat16', return_all_prompt_logprobs=False, use_reward_model=False, long_truncation_mode=None, use_tgi_api=False, health_monitor=False, metric_gateway=None, job_name='lightllm', grouping_key=[], push_interval=10, visual_infer_batch_size=1, visual_send_batch_size=1, visual_gpu_ids=[0], visual_tp=1, visual_dp=1, visual_nccl_ports=[29500], enable_monitor_auth=False, disable_cudagraph=True, enable_prefill_cudagraph=False, prefll_cudagraph_max_handle_token=512, graph_max_batch_size=256, graph_split_batch_size=32, graph_grow_step_size=16, graph_max_len_in_batch=16384, quant_type='none', quant_cfg=None, vit_quant_type='none', vit_quant_cfg=None, sampling_backend='triton', penalty_counter_mode='gpu_counter', ep_redundancy_expert_config_path=None, auto_update_redundancy_expert=False, enable_fused_shared_experts=False, mtp_mode=None, mtp_draft_model_dir=None, mtp_step=0, kv_quant_calibration_config_path=None, schedule_time_interval=0.03, enable_cpu_cache=False, cpu_cache_storage_size=2, cpu_cache_token_page_size=256, enable_disk_cache=False, disk_cache_storage_size=10, disk_cache_dir=None, enable_dp_prompt_cache_fetch=False, router_port=10105, detokenization_port=10128, http_server_port=10009, visual_port=10002, audio_port=10268, cache_port=10173, metric_port=10255, multi_level_kv_cache_port=10190, pd_node_infer_rpyc_ports=[10305], pd_node_id=294623010895931863621527973304373176200, 
pd_p_allowed_port_min=20000, pd_p_allowed_port_max=30000) WARNING 01-02 12:22:55 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-02 12:22:55 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. INFO 01-02 12:22:55 [__init__.py:36] Available plugins for group vllm.platform_plugins: INFO 01-02 12:22:55 [__init__.py:38] - musa -> vllm_musa:register INFO 01-02 12:22:55 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load. INFO 01-02 12:22:55 [__init__.py:232] Platform plugin musa is activated WARNING 01-02 12:22:55 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-02 12:22:55 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. 2026-01-02 12:22:55 | server | 140684395422848 | INFO : server started on [0.0.0.0]:10255 INFO 01-02 12:22:55 [start_utils.py:37] init func start_metric_manager : init ok WARNING 01-02 12:23:02 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-02 12:23:02 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-02 12:23:02 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-02 12:23:02 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. INFO 01-02 12:23:02 [__init__.py:36] Available plugins for group vllm.platform_plugins: INFO 01-02 12:23:02 [__init__.py:38] - musa -> vllm_musa:register INFO 01-02 12:23:02 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load. INFO 01-02 12:23:02 [__init__.py:232] Platform plugin musa is activated WARNING 01-02 12:23:02 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-02 12:23:02 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. INFO 01-02 12:23:02 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On INFO 01-02 12:23:02 [__init__.py:36] Available plugins for group vllm.platform_plugins: INFO 01-02 12:23:02 [__init__.py:38] - musa -> vllm_musa:register INFO 01-02 12:23:02 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load. INFO 01-02 12:23:02 [__init__.py:232] Platform plugin musa is activated WARNING 01-02 12:23:02 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-02 12:23:02 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. WARNING 01-02 12:23:02 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm INFO 01-02 12:23:02 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-02 12:23:03 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm INFO 01-02 12:23:03 [manager.py:36] pub_to_httpserver sendhwm 1000 WARNING 01-02 12:23:03 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!! 
2026-01-02 12:23:03 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 36414) with fd 25 2026-01-02 12:23:03 | server | 140653235951168 | INFO : welcome ('127.0.0.1', 36414) INFO 01-02 12:23:08 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-02 12:23:09 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. INFO 01-02 12:23:10 [__init__.py:36] Available plugins for group vllm.platform_plugins: INFO 01-02 12:23:10 [__init__.py:38] - musa -> vllm_musa:register INFO 01-02 12:23:10 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load. INFO 01-02 12:23:10 [__init__.py:232] Platform plugin musa is activated WARNING 01-02 12:23:10 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. WARNING 01-02 12:23:10 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-02 12:23:10 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm INFO 01-02 12:23:10 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. WARNING 01-02 12:23:10 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!! INFO 01-02 12:23:10 [model_rpc.py:67] Initialized RPC server for rank 0. INFO 01-02 12:23:10 [model_rpc.py:168] use ChunkedPrefillBackend INFO 01-02 12:23:11 [basemodel.py:157] Initial quantization. The default quantization method is none pid 39235 Loading model weights with 1 workers: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1.01it/s] INFO 01-02 12:23:12 [mem_utils.py:37] mode setting params: [] INFO 01-02 12:23:12 [mem_utils.py:57] Model kv cache using mode normal INFO 01-02 12:23:12 [mem_manager.py:84] 69.38735313415528 GB space is available after load the model weight INFO 01-02 12:23:12 [mem_manager.py:84] 0.109375 MB is the size of one token kv cache INFO 01-02 12:23:12 [mem_manager.py:84] 649624 is the profiled max_total_token_num with the mem_fraction 0.9 INFO 01-02 12:23:12 [mem_manager.py:84] warming up: 0%| | 0/12 [00:00 INFO 01-02 12:23:45 [manager.py:196] use req queue ChunkedPrefillQueue INFO 01-02 12:23:45 [start_utils.py:37] init func start_router_process : init ok INFO 01-02 12:23:45 [start_utils.py:37] init func start_detokenization_process : init ok INFO 01-02 12:23:45 [api_start.py:58] start process pid 30307 INFO 01-02 12:23:45 [api_start.py:59] http server pid 54746 [2026-01-02 12:23:45 +0800] [54746] [INFO] Starting gunicorn 23.0.0 [2026-01-02 12:23:45 +0800] [54746] [INFO] Listening at: http://0.0.0.0:8000 (54746) [2026-01-02 12:23:45 +0800] [54746] [INFO] Using worker: uvicorn.workers.UvicornWorker [2026-01-02 12:23:45 +0800] [54966] [INFO] Booting worker with pid: 54966 WARNING 01-02 12:23:51 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-02 12:23:51 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. INFO 01-02 12:23:52 [__init__.py:36] Available plugins for group vllm.platform_plugins: INFO 01-02 12:23:52 [__init__.py:38] - musa -> vllm_musa:register INFO 01-02 12:23:52 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load. 
INFO 01-02 12:23:52 [__init__.py:232] Platform plugin musa is activated WARNING 01-02 12:23:52 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-02 12:23:52 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. INFO 01-02 12:23:52 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-02 12:23:52 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm [2026-01-02 12:23:52 +0800] [54966] [INFO] Started server process [54966] [2026-01-02 12:23:52 +0800] [54966] [INFO] Waiting for application startup. INFO 01-02 12:23:52 [api_http.py:359] server start up 2026-01-02 12:23:53 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 55128) with fd 26 2026-01-02 12:23:53 | server | 140653227558464 | INFO : welcome ('127.0.0.1', 55128) 2026-01-02 12:23:53 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 55144) with fd 27 2026-01-02 12:23:53 | server | 140653219165760 | INFO : welcome ('127.0.0.1', 55144) INFO 01-02 12:23:54 [req_id_generator.py:34] ReqIDGenerator init finished INFO 01-02 12:23:54 [api_http.py:363] server start up ok, loop use is [2026-01-02 12:23:54 +0800] [54966] [INFO] Application startup complete. INFO 01-02 12:23:58 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-02 12:23:58 lightllm_req_id:8 INFO 01-02 12:23:58 [manager.py:424] router recive req id 8 cost time 0.05271601676940918 s DEBUG 01-02 12:23:58 [manager.py:322] Prefill Batch: batch_id=-1, time:1767327838.6764812s req_ids:[8] DEBUG 01-02 12:23:58 [manager.py:322] INFO 01-02 12:23:58 [manager.py:55] detokenization recv req id 8 cost time 0.0744318962097168 s INFO 01-02 12:23:59 [manager.py:163] detoken release req id 8 INFO 01-02 12:23:59 [manager.py:611] X-Request-Id: X-Session-Id: start_time:2026-01-02 12:23:58 lightllm_req_id:8 first_token_cost:409.63053703308105ms total_cost_time:907.1474075317383ms,out_token_counter:17 mean_per_token_cost_time: 29.265698264626895ms prompt_token_num:4 gpu cache hit: False gpu_prompt_cache_len:0 gpu_prompt_cache_ratio:0.0 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:38158 - "POST /generate HTTP/1.1" 200 DEBUG 01-02 12:23:59 [req_manager.py:78] freed all request size 1008 DEBUG 01-02 12:23:59 [infer_batch.py:172] free a batch state: DEBUG 01-02 12:23:59 [infer_batch.py:172] radix refed token num 0 DEBUG 01-02 12:23:59 [infer_batch.py:172] radix hold token num 21 DEBUG 01-02 12:23:59 [infer_batch.py:172] mem manager can alloc token num 649603 DEBUG 01-02 12:23:59 [infer_batch.py:172] mem manager total size 649624 INFO 01-02 12:23:59 [batch.py:56] router release req id 8 INFO 01-02 12:23:59 [shm_req_manager.py:111] all shm req has been release ok ``` Signed-off-by: Xiaodong Ye --- lightllm/__init__.py | 4 +++ .../token_attention_nopad_att1.py | 3 +- lightllm/utils/device_utils.py | 31 +++++++++++++++---- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/lightllm/__init__.py b/lightllm/__init__.py index e69de29bb2..e9ba6f3041 100644 --- a/lightllm/__init__.py +++ b/lightllm/__init__.py @@ -0,0 +1,4 @@ +from lightllm.utils.device_utils import is_musa + +if is_musa(): + import torchada # noqa: F401 diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py index eb5af6fecd..45de83e989 100644 --- 
a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @@ -60,7 +60,8 @@ def _fwd_kernel_token_att1( ).to(tl.int64) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32) + att_value = tl.sum(q[None, :] * k, 1) + att_value = att_value.to(tl.float32) att_value *= sm_scale off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index cd48a355bd..09d7a680fe 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps): return num_sm +@lru_cache(maxsize=1) +def is_musa(): + return hasattr(torch.version, "musa") and torch.version.musa is not None + + @lru_cache(maxsize=None) def get_current_device_name(): - import torch - - if torch.cuda.is_available(): + if torch.cuda.is_available() or is_musa(): device = torch.cuda.current_device() gpu_name = torch.cuda.get_device_name(device) # 4090 trans to 4090 D @@ -103,8 +106,6 @@ def init_p2p(device_index): """ torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。 """ - import torch - num_gpus = torch.cuda.device_count() tensor = torch.zeros((1,)) tensor = tensor.to(f"cuda:{device_index}") @@ -127,8 +128,26 @@ def has_nvlink(): result = result.decode("utf-8") # Check if the output contains 'NVLink' return any(f"NV{i}" in result for i in range(1, 8)) + except FileNotFoundError: + # nvidia-smi is not installed, assume no NVLink + return False + except subprocess.CalledProcessError: + # If there's an error while executing nvidia-smi, assume no NVLink + return False + + +def has_mtlink(): + try: + # Call mthreads-gmi to get the topology matrix + result = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"]) + result = result.decode("utf-8") + # Check if the output contains 'MTLink' + return any(f"MT{i}" in result for i in range(1, 8)) + except FileNotFoundError: + # mthreads-gmi is not installed, assume no MTLink + return False except subprocess.CalledProcessError: - # If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink + # If there's an error while executing mthreads-gmi, assume no MTLink return False From 11b4de358371e30ca2290b5c356fc16ae783f528 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 8 Dec 2025 20:52:03 +0800 Subject: [PATCH 02/43] Rl weight (#1143) Co-authored-by: sufubao --- lightllm/common/basemodel/basemodel.py | 9 - .../layer_weights/meta_weights/__init__.py | 5 +- .../{ => fused_moe}/fused_moe_weight_ep.py | 182 +++++++-- .../fused_moe_weight_ep_redundancy.py | 12 +- .../fused_moe/fused_moe_weight_tp.py | 325 ++++++++++++++++ .../gpt_oss_fused_moe_weight_tp.py | 2 +- .../meta_weights/mm_weight/__init__.py | 9 +- .../meta_weights/mm_weight/colmm_weight.py | 82 +---- .../meta_weights/mm_weight/mm_factory.py | 90 ----- .../meta_weights/mm_weight/mm_slicer.py | 18 + .../meta_weights/mm_weight/mm_weight.py | 348 +++--------------- .../meta_weights/mm_weight/rowmm_weight.py | 88 +---- .../layer_weights/meta_weights/norm_weight.py | 152 +++----- .../layer_weights/transformer_layer_weight.py | 6 +- 
lightllm/common/quantization/__init__.py | 5 +- lightllm/common/quantization/awq_quant.py | 139 ++++--- .../common/quantization/deepgemm_quant.py | 55 ++- lightllm/common/quantization/no_quant.py | 52 +++ .../common/quantization/quantize_method.py | 66 +++- lightllm/common/quantization/registry.py | 5 +- lightllm/common/quantization/torchao_quant.py | 9 +- .../quantization/triton_quant/triton_quant.py | 43 ++- lightllm/common/quantization/w8a8_quant.py | 100 +++-- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 49 ++- .../layer_weights/transformer_layer_weight.py | 1 + .../pre_and_post_layer_weight.py | 1 + .../layer_weights/transformer_layer_weight.py | 9 - .../layer_weights/transformer_layer_weight.py | 5 +- .../pre_and_post_layer_weight.py | 2 + .../pre_and_post_layer_weight.py | 1 + .../pre_and_post_layer_weight.py | 54 ++- .../mode_backend/redundancy_expert_manager.py | 4 +- 33 files changed, 1047 insertions(+), 885 deletions(-) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep.py (74%) rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/fused_moe_weight_ep_redundancy.py (96%) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py rename lightllm/common/basemodel/layer_weights/meta_weights/{ => fused_moe}/gpt_oss_fused_moe_weight_tp.py (99%) delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py create mode 100644 lightllm/common/quantization/no_quant.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 011f998fc0..2e4a188d0a 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -175,15 +175,6 @@ def _init_weights(self, start_layer_index=0): ) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) - self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] return def _init_mem_manager(self): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 0fa02780cb..72e0034cb8 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -1,13 +1,12 @@ from .base_weight import BaseWeight from .mm_weight import ( - MMWeightPack, MMWeightTpl, ROWMMWeight, COLMMWeight, ROWBMMWeight, ) from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight -from .fused_moe_weight_tp import create_tp_moe_wegiht_obj -from .fused_moe_weight_ep import FusedMoeWeightEP from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight +from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj +from .fused_moe.fused_moe_weight_ep import FusedMoeWeightEP diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py similarity index 74% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py rename to 
lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 7dc5b5fdcc..0923d5dea0 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id -from .base_weight import BaseWeight +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight from lightllm.common.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, @@ -23,6 +23,7 @@ from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.quantization.quantize_method import WeightPack logger = init_logger(__name__) @@ -41,6 +42,7 @@ def __init__( network_config: Dict[str, Any], layer_num: int, quant_cfg=None, + hidden_size: Optional[int] = None, ) -> None: super().__init__() @@ -62,6 +64,7 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.n_routed_experts = n_routed_experts self.data_type_ = data_type + self.hidden_size = hidden_size global_world_size = get_global_world_size() self.global_rank_ = get_global_rank() @@ -78,6 +81,7 @@ def __init__( assert self.n_routed_experts % global_world_size == 0 self.ep_n_routed_experts = self.n_routed_experts // global_world_size ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num + self.ep_load_expert_num = ep_load_expert_num self.experts_up_projs = [None] * ep_load_expert_num self.experts_gate_projs = [None] * ep_load_expert_num self.experts_up_proj_scales = [None] * ep_load_expert_num @@ -105,6 +109,51 @@ def __init__( # auto update redundancy expert vars self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert + # Pre-allocate memory if hidden_size is provided + if self.hidden_size is not None: + self._create_weight() + + def _create_weight(self): + """Pre-allocate GPU memory for fused MoE weights""" + if self.hidden_size is None: + return + + total_expert_num = self.ep_load_expert_num + # We need to determine intermediate size from network config or use a default + # This will be updated when first weight is loaded if needed + intermediate_size = getattr(self, "intermediate_size", None) + if intermediate_size is None: + # Default fallback - this will be corrected during load + intermediate_size = self.hidden_size * 4 + + device_id = get_current_device_id() + + if not self.quantized_weight and self.quant_method is not None: + # Quantized weights + w1_pack = self.quant_method.create_weight( + total_expert_num * intermediate_size * 2, self.hidden_size, dtype=self.data_type_, device_id=device_id + ) + self.w1[0] = w1_pack.weight.view(total_expert_num, intermediate_size * 2, self.hidden_size) + self.w1[1] = w1_pack.weight_scale.view(total_expert_num, intermediate_size * 2, self.hidden_size) + + w2_pack = self.quant_method.create_weight( + total_expert_num * self.hidden_size, intermediate_size, dtype=self.data_type_, device_id=device_id + ) + self.w2[0] = w2_pack.weight.view(total_expert_num, self.hidden_size, intermediate_size) + self.w2[1] = w2_pack.weight_scale.view(total_expert_num, self.hidden_size, intermediate_size) + else: + # Regular 
weights + self.w1[0] = torch.empty( + (total_expert_num, intermediate_size * 2, self.hidden_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + self.w2[0] = torch.empty( + (total_expert_num, self.hidden_size, intermediate_size), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + def experts( self, input_tensor, @@ -422,12 +471,12 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = self.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale + qw1_pack = self.quant_method.quantize(w1) + qw2_pack = self.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = self._cuda(w1) self.w2[0] = self._cuda(w2) @@ -469,38 +518,74 @@ def _fuse_weight_scale(self): def load_hf_weights(self, weights): n_expert_ep = self.ep_n_routed_experts - # tp to ep here + + # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) + # Get weight shapes from first expert to determine intermediate size + first_expert_idx = 0 + n_expert_ep * self.global_rank_ + w1_weight_name = f"{self.weight_prefix}.{first_expert_idx}.{self.w1_weight_name}.weight" + if w1_weight_name in weights: + intermediate_size = weights[w1_weight_name].shape[0] + self.intermediate_size = intermediate_size + + # Re-create weights with correct size if needed + if self.w1[0].shape[1] != intermediate_size * 2: + self._create_weight() + + # Load regular experts for i_experts_ep in range(n_expert_ep): i_experts = i_experts_ep + n_expert_ep * self.global_rank_ - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[i_experts_ep] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[i_experts_ep] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[i_experts_ep] = weights[w2_weight] - - # Load weight parameters for redundant experts + self._copy_expert_weights(i_experts_ep, i_experts, weights) + + # Load redundant experts for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): - i_experts = redundant_expert_id - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - if w1_weight in weights: - self.experts_gate_projs[n_expert_ep + i] = weights[w1_weight] - if w3_weight in weights: - self.experts_up_projs[n_expert_ep + i] = weights[w3_weight] - if w2_weight in weights: - self.w2_list[n_expert_ep + i] = weights[w2_weight] + self._copy_expert_weights(n_expert_ep + i, redundant_expert_id, weights) if self.quantized_weight: - self._load_weight_scale(weights) - self._fuse() + self._load_weight_scale_direct(weights) + + def _copy_expert_weights(self, target_idx, expert_id, weights): + """Copy a single expert's 
weights to pre-allocated GPU memory""" + w1_weight = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.weight" + w2_weight = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.weight" + w3_weight = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.weight" + + intermediate_size = self.intermediate_size + + if w1_weight in weights and w3_weight in weights: + # Combine gate and up projections into w1 + gate_weight = weights[w1_weight] # [intermediate_size, hidden_size] + up_weight = weights[w3_weight] # [intermediate_size, hidden_size] + + # Copy to pre-allocated memory + if not self.quantized_weight and self.quant_method is not None: + # Quantized path + combined_cpu = torch.empty((intermediate_size * 2, self.hidden_size), dtype=gate_weight.dtype) + combined_cpu[:intermediate_size, :] = gate_weight + combined_cpu[intermediate_size:, :] = up_weight + quantized_pack = self.quant_method.quantize(combined_cpu) + self.w1[0][target_idx].copy_(quantized_pack.weight.view(intermediate_size * 2, self.hidden_size)) + if quantized_pack.weight_scale is not None: + self.w1[1][target_idx].copy_( + quantized_pack.weight_scale.view(intermediate_size * 2, self.hidden_size) + ) + else: + # Regular path + self.w1[0][target_idx, :intermediate_size, :].copy_(gate_weight) + self.w1[0][target_idx, intermediate_size:, :].copy_(up_weight) + + if w2_weight in weights: + # Copy w2 (down projection) + w2_weight_tensor = weights[w2_weight] # [hidden_size, intermediate_size] - already the correct shape + if not self.quantized_weight and self.quant_method is not None: + quantized_pack = self.quant_method.quantize(w2_weight_tensor) + self.w2[0][target_idx].copy_(quantized_pack.weight) + if quantized_pack.weight_scale is not None: + self.w2[1][target_idx].copy_(quantized_pack.weight_scale) + else: + self.w2[0][target_idx].copy_(w2_weight_tensor) def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: n_expert_ep = self.ep_n_routed_experts @@ -530,6 +615,41 @@ def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: if w2_scale in weights: self.w2_scale_list[n_expert_ep + i] = weights[w2_scale] + def _load_weight_scale_direct(self, weights: Dict[str, torch.Tensor]) -> None: + """Load weight scales directly to pre-allocated GPU memory""" + n_expert_ep = self.ep_n_routed_experts + + # Load regular expert scales + for i_experts_ep in range(n_expert_ep): + i_experts = i_experts_ep + n_expert_ep * self.global_rank_ + self._copy_expert_scales(i_experts_ep, i_experts, weights) + + # Load redundant expert scales + for i, redundant_expert_id in enumerate(self.redundancy_expert_ids): + self._copy_expert_scales(n_expert_ep + i, redundant_expert_id, weights) + + def _copy_expert_scales(self, target_idx, expert_id, weights): + """Copy a single expert's weight scales to pre-allocated GPU memory""" + w1_scale = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.{self.weight_scale_suffix}" + w2_scale = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.{self.weight_scale_suffix}" + w3_scale = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.{self.weight_scale_suffix}" + + intermediate_size = self.intermediate_size + + if w1_scale in weights and w3_scale in weights: + # Combine gate and up projection scales into w1 scale + gate_scale = weights[w1_scale] # [intermediate_size, hidden_size] + up_scale = weights[w3_scale] # [intermediate_size, hidden_size] + + # Copy to pre-allocated memory + self.w1[1][target_idx, :intermediate_size, :].copy_(gate_scale) + 
self.w1[1][target_idx, intermediate_size:, :].copy_(up_scale) + + if w2_scale in weights: + # Copy w2 scale (down projection) + w2_scale_tensor = weights[w2_scale] # [hidden_size, intermediate_size] + self.w2[1][target_idx].copy_(w2_scale_tensor) + def _cuda(self, cpu_tensor): device_id = get_current_device_id() if self.quantized_weight: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py similarity index 96% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py index b53200d4c8..933a94f78c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py @@ -102,12 +102,12 @@ def _fuse(self): inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = self._ep_w.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self._ep_w.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale + qw1_pack = self._ep_w.quant_method.quantize(w1) + qw2_pack = self._ep_w.quant_method.quantize(w2) + self.w1[0] = qw1_pack.weight + self.w1[1] = qw1_pack.weight_scale + self.w2[0] = qw2_pack.weight + self.w2[1] = qw2_pack.weight_scale else: self.w1[0] = w1 self.w2[0] = w2 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py new file mode 100644 index 0000000000..bf7b218b71 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -0,0 +1,325 @@ +import os +import torch +import threading +from typing import Tuple, List, Dict, Any, Union, Callable +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id, get_dp_world_size +from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( + get_row_slice_mixin, + get_col_slice_mixin, +) + + +def create_tp_moe_wegiht_obj( + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, +) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: + quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + if quant_method is not None and quant_method.method_name == "awq_marlin": + return FusedAWQMARLINMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + 
num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + else: + return FusedMoeWeightTP( + gate_proj_name=gate_proj_name, + down_proj_name=down_proj_name, + up_proj_name=up_proj_name, + e_score_correction_bias_name=e_score_correction_bias_name, + weight_prefix=weight_prefix, + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + split_inter_size=split_inter_size, + data_type=data_type, + network_config=network_config, + layer_num=layer_num, + quant_cfg=quant_cfg, + ) + + +class FusedMoeWeightTP(BaseWeight): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + num_fused_shared_experts: int, + split_inter_size: int, + data_type: torch.dtype, + network_config: Dict[str, Any], + layer_num: int, + quant_cfg: Quantcfg = None, + ) -> None: + super().__init__() + self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") + self.quantized_weight = quant_cfg.quantized_weight + if self.quant_method.method_name != "none": + self.weight_scale_suffix = self.quant_method.weight_scale_suffix + + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.split_inter_size = split_inter_size + self.data_type_ = data_type + self.hidden_size = network_config.get("hidden_size") + self.tp_rank_ = get_current_rank_in_dp() + self.e_score_correction_bias = None + self.scoring_func = network_config.get("scoring_func", "softmax") + self.row_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self.col_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + ) + self._create_weight() + + def _create_weight(self): + total_expert_num = self.n_routed_experts + intermediate_size = self.split_inter_size + device_id = get_current_device_id() + + # Create e_score_correction_bias + if self.e_score_correction_bias is not None: + self.e_score_correction_bias = torch.empty( + (total_expert_num,), + dtype=self.data_type_, + device=f"cuda:{device_id}", + ) + + self.w13: WeightPack = self.quant_method.create_weight( + out_dim=intermediate_size * 2, + in_dim=self.hidden_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + self.w2: WeightPack = self.quant_method.create_weight( + out_dim=self.hidden_size, + in_dim=intermediate_size, + dtype=self.data_type_, + device_id=device_id, + num_experts=total_expert_num, + ) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, 
+ top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w13, w13_scale = self.w13.weight, self.w13.weight_scale + w2, w2_scale = self.w2.weight, self.w2.weight_scale + use_fp8_w8a8 = self.quant_method.method_name != "none" + + from lightllm.common.fused_moe.grouped_fused_moe import fused_experts + + fused_experts( + hidden_states=input_tensor, + w1=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + return + + def _cuda(self, cpu_tensor): + device_id = get_current_device_id() + if self.quantized_weight: + return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(device_id) + + def verify_load(self): + return True + + def load_hf_weights(self, weights): + # Load bias + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + + # Load each expert with TP slicing + for i_experts in range(self.n_routed_experts): + self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) + if self.w13.weight_scale is not None: + self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix) + if self.w13.weight_zero_point is not None: + self._load_expert( + i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix + ) + + def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, weight_pack, start_idx) + else: + self.quant_method.load_weight(weight, weight_pack, start_idx) + + def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): + w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" + w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" + w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" + intermediate_size = self.split_inter_size + load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + if w3_weight in weights: + load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) + + load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) + if w2_weight in weights: + load_func(slice_func(weights[w2_weight]), self.w2.get_expert(expert_idx), start_idx=0) + + def _get_load_and_slice_func(self, type: str, is_row: bool = True): + if is_row: + slicer = self.row_slicer + else: + slicer = self.col_slicer + if type == "weight": + return self._load_weight_func, slicer._slice_weight + elif type == 
"weight_scale": + return getattr(self.quant_method, "load_weight_scale"), slicer._slice_weight_scale + elif type == "weight_zero_point": + return getattr(self.quant_method, "load_weight_zero_point"), slicer._slice_weight_zero_point + + +class FusedAWQMARLINMoeWeightTP(FusedMoeWeightTP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + ) + + self.workspace = marlin_make_workspace_new(self.w13.weight.device, 4) + + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + from lightllm.common.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=self.e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=self.scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts - self.num_fused_shared_experts, + end=self.n_routed_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + + w1, w1_scale, w1_zero_point = self.w13.weight, self.w13.weight_scale, self.w13.weight_zero_point + w2, w2_scale, w2_zero_point = self.w2.weight, self.w2.weight_scale, self.w2.weight_zero_point + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1, + w2, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + + return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py similarity index 99% rename from lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index df72cc6208..9d79ff7c25 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -3,7 +3,7 @@ import threading from typing import Optional, Tuple, List, Dict, Any -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_tp import FusedMoeWeightTP from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from 
lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 63605b1774..34d989b01f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,10 +1,5 @@ from .mm_weight import ( - MMWeightPack, MMWeightTpl, ) -from .mm_factory import ( - MMWeight, - ROWMMWeight, - ROWBMMWeight, - COLMMWeight, -) +from .rowmm_weight import ROWMMWeight, ROWBMMWeight +from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index 281f30f022..bf73b9ad89 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -1,19 +1,19 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin +from .mm_slicer import get_col_slice_mixin -class StandardCOLMMWeight(MMWeightTpl): +class COLMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -22,6 +22,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, data_type=data_type, bias_names=bias_names, @@ -29,74 +31,6 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQCOLMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - # 注意这里不是错误,因为awq的weight是按inxout存的 - self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, 
tp_world_size=tp_world_size) - - -class AWQMARLINCOLMMWeight(AWQCOLMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -COLMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight, - "awq": AWQCOLMMWeight, - "awq_marlin": AWQMARLINCOLMMWeight, -} diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py deleted file mode 100644 index 464de84413..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py +++ /dev/null @@ -1,90 +0,0 @@ -from lightllm.common.quantization import Quantcfg -from lightllm.common.quantization.quantize_method import QuantizationMethod -from typing import Type, Union, Dict -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, - BMMWeightTpl, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import ( - StandardROWMMWeight, - UnquantizedROWBMMWeight, - ROWMM_WEIGHT_CLS_MAP, -) -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import ( - StandardCOLMMWeight, - COLMM_WEIGHT_CLS_MAP, -) - - -class MMWeight: - def __new__(cls, **kwargs): - """ - weight_names, - data_type, - bias_names, - quant_cfg, - layer_num, - name, - tp_rank, - tp_world_size, - ... - 该类主要是通过重载 __new__ 为对应的mm权重绑定量化方法,其他参数都是透传。 - """ - - quant_cfg = kwargs.pop("quant_cfg", None) - layer_num_ = kwargs.pop("layer_num", None) - name = kwargs.pop("name", None) - quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name) - # quantized_weight 本身是用来标识权重本身在文件中是否是以量化后的形式存储, - # 现在不再使用该参数,是否量化由后续的加载过程自动识别。 - kwargs["quant_method"] = quant_method - mmcls = cls._get_mmcls(quant_method) - return mmcls(**kwargs) - - @classmethod - def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod: - if quant_cfg is None: - return None, False - quant_method: QuantizationMethod = quant_cfg.get_quant_method(layer_num_, name) - if quant_method is None: - return None, False - quant_method.hf_quantization_config = quant_cfg.hf_quantization_config - quantized_weight = quant_cfg.quantized_weight - return quant_method, quantized_weight - - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod) -> Type[Union[MMWeightTpl, BMMWeightTpl]]: - raise NotImplementedError("Subclasses must implement _get_mmcls method") - - -class ROWMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return StandardROWMMWeight - - return ROWMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardROWMMWeight, - ) - - -class ROWBMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return UnquantizedROWBMMWeight - else: - # TODO: Implement more quantization weight - raise NotImplementedError("ROWBMMWeight is not implemented") - - -class COLMMWeight(MMWeight): - @classmethod - def _get_mmcls(cls, quant_method: QuantizationMethod): - if quant_method is None: - return 
StandardCOLMMWeight - return COLMM_WEIGHT_CLS_MAP.get( - quant_method.method_name, - StandardCOLMMWeight, - ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index e3ef5b0ea3..e2830ab611 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -132,3 +132,21 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None): def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias / self.tp_world_size_ + + +def get_row_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return RowSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedRowSliceMixin(tp_rank, tp_world_size) + + +def get_col_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: + if quant_method_name.startswith("awq"): + return AwqQuantizedColSliceMixin(tp_rank, tp_world_size) + elif quant_method_name == "none": + return ColSliceMixin(tp_rank, tp_world_size) + else: + return QuantizedColSliceMixin(tp_rank, tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 7391454da0..92236b798b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -5,9 +5,10 @@ from dataclasses import dataclass from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg +from lightllm.common.quantization.no_quant import NoQuantization from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger from .mm_slicer import SliceMixinTpl @@ -15,53 +16,11 @@ logger = init_logger(__name__) -@dataclass -class MMWeightPack: - weight: Optional[torch.Tensor] = None - bias: Optional[torch.Tensor] = None - weight_scale: Optional[torch.Tensor] = None - weight_zero_point: Optional[torch.Tensor] = None - - has_bias: bool = False - has_weight_scale: bool = False - has_weight_zero_point: bool = False - - def is_ready(self) -> bool: - return ( - self.weight is not None - and (not self.has_bias or (self.has_bias and self.bias is not None)) - and (not self.has_weight_scale or (self.has_weight_scale and self.weight_scale is not None)) - and (not self.has_weight_zero_point or (self.has_weight_zero_point and self.weight_zero_point is not None)) - ) - - def ready_for_fused_merge(self) -> bool: - """ - 判断权重是否满足可以和其他权重进行融合cat的条件,因为可能权重是量化和非量化后的权重,所以复杂一些。 - """ - weight_ready = self.weight is not None and self.weight.dtype in [ - torch.bfloat16, - torch.float16, - torch.float32, - torch.float64, - ] - bias_ready = (self.has_bias and self.bias is not None) or (not self.has_bias) - if weight_ready and bias_ready: - 
return True - else: - return self.is_ready() - - def is_load_finished(self): - return ( - (self.is_ready() and self.weight.is_cuda) - and ((self.has_bias and self.bias.is_cuda) or (not self.has_bias)) - and ((self.has_weight_scale and self.weight_scale.is_cuda) or (not self.has_weight_scale)) - and ((self.has_weight_zero_point and self.weight_zero_point.is_cuda) or (not self.has_weight_zero_point)) - ) - - class MMWeightTpl(BaseWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], bias_names: Optional[Union[str, List[str]]], data_type: torch.dtype, @@ -72,6 +31,14 @@ def __init__( super().__init__(tp_rank, tp_world_size, data_type) self.lock = threading.Lock() + self.in_dim = in_dim + if isinstance(out_dims, int): + out_dims = [out_dims] + self.out_dims = out_dims + self.cusum_out_dims = [0] + for out_dim in out_dims[:-1]: + self.cusum_out_dims.append(self.cusum_out_dims[-1] + out_dim) + if isinstance(weight_names, str): weight_names = [weight_names] if isinstance(bias_names, str): @@ -82,60 +49,29 @@ def __init__( if bias_names[0] is None: bias_names = None - if quant_method is not None: - has_weight_scale = quant_method.has_weight_scale - has_weight_zero_point = quant_method.has_weight_zero_point - else: - has_weight_scale = False - has_weight_zero_point = False - # 同时存在 weight_names 和 quanted_weight_names 是为了兼容在线和离线两种加载方案 self.weight_names = weight_names - self.bias_names = bias_names - has_bias = self.bias_names is not None - - self.gen_weight_quant_param_names(quant_method=quant_method) - self.quant_method = quant_method - self.sub_child_mm_params: List[MMWeightPack] = [ - MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) - for _ in range(len(weight_names)) - ] - self.mm_param: MMWeightPack = MMWeightPack( - has_bias=has_bias, - has_weight_scale=has_weight_scale, - has_weight_zero_point=has_weight_zero_point, - ) + self.quant_method: QuantizationMethod = NoQuantization() if quant_method is None else quant_method self.param_slicer: SliceMixinTpl = None + self._create_weight() + self.gen_weight_quant_param_names(quant_method=quant_method) - self.weight_fused_dim = 0 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 0 - - self.load_finished: bool = False + def _create_weight(self): + self.bias = None + if self.bias_names is not None: + self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) + self.mm_param: WeightPack = self.quant_method.create_weight( + in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() + ) + return def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: - if self.quant_method is not None: - return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger - ) - if out is None: - shape = (input_tensor.shape[0], self.mm_param.weight.shape[1]) - dtype = input_tensor.dtype - device = input_tensor.device - if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device) - else: - out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: - return torch.mm(input_tensor, self.mm_param.weight, out=out) - return torch.addmm(self.mm_param.bias, input_tensor, self.mm_param.weight, out=out) + return 
self.quant_method.apply( + input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + ) def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]): if quant_method is None: @@ -176,8 +112,6 @@ def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod return def load_hf_weights(self, weights): - if self.mm_param.is_load_finished(): - return for sub_child_index, param_name in enumerate(self.weight_names): self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) @@ -196,51 +130,8 @@ def load_hf_weights(self, weights): for sub_child_index, param_name in enumerate(self.weight_zero_point_names): self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - with self.lock: - # 如果需要fused的请求,全部ok了以后进行merge操作。, all([]) 竟然返回是True, 需要len(self.sub_child_mm_params) > 0 的额外判断。 - if len(self.sub_child_mm_params) > 0 and all(e.ready_for_fused_merge() for e in self.sub_child_mm_params): - self._fuse_weights() - self.sub_child_mm_params.clear() - - # 在线量化操作 - if ( - self.quant_method is not None - and self.mm_param.weight is not None - and self.quant_method.weight_need_quanted(self.mm_param.weight) - and self.load_finished is False - ): - logger.info(f"online quant weight names: {self.weight_names}") - quantized_weight, weight_scale, weight_zero_point = self.quant_method.quantize( - self.mm_param.weight.cuda(get_current_device_id()) - ) - self.mm_param.weight = quantized_weight - self.mm_param.weight_scale = weight_scale - self.mm_param.weight_zero_point = weight_zero_point - - # repack 操作 - if ( - self.quant_method is not None - and self.mm_param.is_ready() - and self.quant_method.params_need_repack() - and self.load_finished is False - ): - ( - self.mm_param.weight, - self.mm_param.weight_scale, - self.mm_param.weight_zero_point, - ) = self.quant_method.params_repack( - weight=self.mm_param.weight, - weight_scale=self.mm_param.weight_scale, - weight_zero_point=self.mm_param.weight_zero_point, - dtype_type=self.data_type_, - ) - - if self.mm_param.is_ready() and self.load_finished is False: - self._to_gpu_device() - self.load_finished = True - def verify_load(self) -> bool: - return self.mm_param.is_ready() + return True # 执行顺序 def _load_weight( @@ -248,7 +139,11 @@ def _load_weight( ) -> None: if param_name in weights: weight = self.param_slicer._slice_weight(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight = weight + start_idx = self.cusum_out_dims[sub_child_index] + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, self.mm_param, offset=start_idx) + else: + self.quant_method.load_weight(weight, self.mm_param, start_idx) return def _load_bias( @@ -256,7 +151,9 @@ def _load_bias( ) -> None: if param_name in weights: bias = self.param_slicer._slice_bias(weights[param_name]) - self.sub_child_mm_params[sub_child_index].bias = bias + start_idx = self.cusum_out_dims[sub_child_index] + end_idx = start_idx + bias.shape[0] + self.mm_param.bias[start_idx:end_idx].copy_(bias) return def _load_weight_scale( @@ -264,7 +161,8 @@ def _load_weight_scale( ) -> None: if param_name in weights: weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_scale = weight_scale + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_scale(weight_scale, self.mm_param, start_idx) return def _load_weight_zero_point( 
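The hunks around this point replace the old cat-after-load fusion (`_fuse_weights`) with direct copies into a preallocated fused buffer: `cusum_out_dims` gives each sub-weight's row offset, and the `quant_method.load_weight*` helpers copy each loaded shard into its slice, so the per-shard tensors and a concatenated copy never coexist in memory. Below is a minimal standalone sketch of that offset-based loading idea only; the dimensions are made up and plain tensors stand in for the patch's `WeightPack`/`QuantizationMethod` types.

```python
import torch

# e.g. fusing k_proj and v_proj into one buffer
out_dims = [8, 8]
in_dim = 4

# cumulative row offsets, mirroring cusum_out_dims in the patch
cusum = [0]
for d in out_dims[:-1]:
    cusum.append(cusum[-1] + d)

# preallocated fused weight buffer (allocated once, up front)
fused = torch.empty(sum(out_dims), in_dim)

def load_sub_weight(sub_idx: int, w: torch.Tensor) -> None:
    # copy one sub-weight shard into its slice of the fused buffer
    start = cusum[sub_idx]
    fused[start:start + w.shape[0]].copy_(w)

load_sub_weight(0, torch.zeros(8, in_dim))  # "k_proj" shard
load_sub_weight(1, torch.ones(8, in_dim))   # "v_proj" shard
print(fused.shape)  # torch.Size([16, 4])
```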
@@ -272,88 +170,8 @@ def _load_weight_zero_point( ) -> None: if param_name in weights: weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) - self.sub_child_mm_params[sub_child_index].weight_zero_point = weight_zero_point - return - - # weight merge - def _fuse_weights(self) -> None: - need_merge = len(self.sub_child_mm_params) > 1 - if self.mm_param.weight is None and all(p.weight is not None for p in self.sub_child_mm_params): - if need_merge: - weight = torch.cat([p.weight for p in self.sub_child_mm_params], dim=self.weight_fused_dim) - else: - weight = self.sub_child_mm_params[0].weight - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight = None - - self.mm_param.weight = weight - - if ( - self.mm_param.has_bias - and self.mm_param.bias is None - and all(p.bias is not None for p in self.sub_child_mm_params) - ): - if need_merge: - bias = torch.cat([p.bias for p in self.sub_child_mm_params], dim=self.bias_fused_dim) - else: - bias = self.sub_child_mm_params[0].bias - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.bias = None - - self.mm_param.bias = bias - - if self.mm_param.weight_scale is None and all(p.weight_scale is not None for p in self.sub_child_mm_params): - if need_merge: - weight_scale = torch.cat( - [p.weight_scale for p in self.sub_child_mm_params], dim=self.weight_scale_and_zero_point_fused_dim - ) - else: - weight_scale = self.sub_child_mm_params[0].weight_scale - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_scale = None - - self.mm_param.weight_scale = weight_scale - - if self.mm_param.weight_zero_point is None and all( - p.weight_zero_point is not None for p in self.sub_child_mm_params - ): - if need_merge: - weight_zero_point = torch.cat( - [p.weight_zero_point for p in self.sub_child_mm_params], - dim=self.weight_scale_and_zero_point_fused_dim, - ) - else: - weight_zero_point = self.sub_child_mm_params[0].weight_zero_point - - # 快速删除,防止占用显存过久 - for p in self.sub_child_mm_params: - p.weight_zero_point = None - - self.mm_param.weight_zero_point = weight_zero_point - return - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - else: - # 让 k dim 更连续,大多数split k 算法的算子可能能更快 - self.mm_param.weight = ( - self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()).transpose(0, 1) - ) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) + start_idx = self.cusum_out_dims[sub_child_index] + self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) return @@ -376,90 +194,6 @@ def bmm( out = g_cache_manager.alloc_tensor(shape, dtype, device=device) else: out = torch.empty(shape, dtype=dtype, device=device) - if self.mm_param.bias is None: + if self.bias is None: return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.mm_param.bias, input_tensor, fpweight, out=out) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - if self.quant_method is not None: - self.mm_param.weight = 
self.mm_param.weight.cuda(get_current_device_id()) - else: - # bmm 不需要 transpose 操作 - self.mm_param.weight = self.mm_param.weight.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class DeepGemmFP8W8A8B128MMWeight(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()).transpose(0, 1) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.cuda(get_current_device_id()).transpose(0, 1) - - assert self.mm_param.has_weight_zero_point is False - - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return - - -class AWQMMWeightTpl(MMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - bias_names: Optional[Union[str, List[str]]] = None, - data_type: torch.dtype = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - bias_names=bias_names, - data_type=data_type, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - self.weight_fused_dim = 1 - self.bias_fused_dim = 0 - self.weight_scale_and_zero_point_fused_dim = 1 - - def _to_gpu_device(self) -> None: - if self.mm_param.weight is not None: - self.mm_param.weight = self.mm_param.weight.cuda(get_current_device_id()) - if self.mm_param.weight_scale is not None: - self.mm_param.weight_scale = self.mm_param.weight_scale.to(self.data_type_).cuda(get_current_device_id()) - if self.mm_param.weight_zero_point is not None: - self.mm_param.weight_zero_point = self.mm_param.weight_zero_point.cuda(get_current_device_id()) - if self.mm_param.bias is not None: - # TODO 是不是所有的bias都需要转换为全局设置的数据类型吗,会不会影响精度 - self.mm_param.bias = self.mm_param.bias.to(self.data_type_).cuda(get_current_device_id()) - return + return torch.addbmm(self.bias, input_tensor, fpweight, out=out) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 0eebdc74d2..e53d643cec 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,20 +1,20 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - DeepGemmFP8W8A8B128MMWeight, - AWQMMWeightTpl, BMMWeightTpl, ) from 
lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union -from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin +from .mm_slicer import get_row_slice_mixin -class StandardROWMMWeight(MMWeightTpl): +class ROWMMWeight(MMWeightTpl): def __init__( self, + in_dim: int, + out_dims: Optional[Union[int, List[int]]], weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -23,6 +23,8 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, bias_names=bias_names, data_type=data_type, @@ -30,32 +32,12 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class DeepGemmFP8W8A8B128ROWMMWeight(DeepGemmFP8W8A8B128MMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size ) - self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - return -class UnquantizedROWBMMWeight(BMMWeightTpl): +class ROWBMMWeight(BMMWeightTpl): def __init__( self, weight_names: Union[str, List[str]], @@ -73,53 +55,5 @@ def __init__( tp_rank=tp_rank, tp_world_size=tp_world_size, ) - self.param_slicer = RowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQROWMMWeight(AWQMMWeightTpl): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size) - - -class AWQMARLINROWMMWeight(AWQROWMMWeight): - def __init__( - self, - weight_names: Union[str, List[str]], - data_type: torch.dtype, - bias_names: Optional[Union[str, List[str]]] = None, - quant_method: QuantizationMethod = None, - tp_rank: int = None, - tp_world_size: int = None, - ) -> None: - super().__init__( - weight_names=weight_names, - data_type=data_type, - bias_names=bias_names, - quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, - ) - - -ROWMM_WEIGHT_CLS_MAP = { - "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128ROWMMWeight, - "awq": AWQROWMMWeight, - "awq_marlin": AWQMARLINROWMMWeight, -} + # bmm 不支持量化运算操作 + self.param_slicer = get_row_slice_mixin(quant_method_name="none", tp_rank=tp_rank, tp_world_size=tp_world_size) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 5a595bff61..619158fa83 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -9,23 +9,36 @@ logger = init_logger(__name__) -class _NormWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type, bias_name=None): +class NormWeight(BaseWeightTpl): + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): super().__init__() + self.norm_dim = norm_dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type - self.weight: torch.Tensor = None - self.bias: Optional[torch.Tensor] = None + self.weight = None + self.bias = None + self.is_weight_ready = False + self.is_bias_ready = False + self._create_weight() + + def _create_weight(self): + device = f"cuda:{get_current_device_id()}" + self.weight = torch.empty(self.norm_dim, dtype=self.data_type_, device=device) + self.bias = ( + torch.empty(self.norm_dim, dtype=self.data_type_, device=device) if self.bias_name is not None else None + ) + + def load_hf_weights(self, weights): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.is_weight_ready = True + if self.bias_name in weights: + self.bias.copy_(weights[self.bias_name]) + self.is_bias_ready = True def verify_load(self): - load_ok = True - # Verify weight. The weight must be not None. - load_ok = load_ok and self.weight is not None - # Verify bias. If bias_name is set, it must be not None. - if self.bias_name is not None: - load_ok = load_ok and self.bias is not None - return load_ok + return self.is_weight_ready and (self.bias_name is None or self.is_bias_ready) def rmsnorm_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty @@ -36,108 +49,29 @@ def rmsnorm_forward( out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) - def layernorm_forward( - self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty - ) -> torch.Tensor: - assert input.ndim == 2 and self.weight.ndim == 1 - assert self.bias is not None - - _tout = layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps) - if out is None: - return _tout - else: - out.copy_(_tout) - return out - -class NoTpNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name=weight_name, data_type=data_type, bias_name=bias_name) - self.tp_world_size_ = 1 - self.tp_rank_ = 0 +class GEMMANormWeight(NormWeight): + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): + super().__init__(norm_dim, weight_name, data_type, bias_name) def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) - if self.bias_name in weights and self.bias is None: - self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + # TODO: 这里直接 +1 会不会导致精度问题? 计算时要求 (1.0 + weight.float()) ? 
+ if self.weight_name in weights: + self.weight.copy_((weights[self.weight_name] + 1).to(self.data_type_)) + self.is_weight_ready = True -class NoTpGEMMANormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - assert self.bias_name is None - self.tp_world_size_ = 1 - self.tp_rank_ = 0 +class TpNormWeight(NormWeight): + def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): + super().__init__(norm_dim, weight_name, data_type, bias_name) def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - self.weight = (weights[self.weight_name] + 1).to(self.data_type_).cuda(get_current_device_id()) - - -class TpVitPadNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, head_num: int, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - self.head_num = head_num - - def _pad_tensor_param(self, weight: torch.Tensor): - assert weight.ndim == 1 - hidden_size = weight.shape[0] - head_dim = hidden_size // self.head_num - assert hidden_size % self.head_num == 0 - - if self.head_num % self.tp_world_size_ == 0: - return weight - else: - logger.warning(f"padding {self.weight_name} weights in TpVitPadNormWeight") - pad_head_num = self.tp_world_size_ - (self.head_num % self.tp_world_size_) - pad_dims = pad_head_num * head_dim - weight = torch.nn.functional.pad(weight, (0, pad_dims), mode="constant", value=0.0) - return weight - - def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - t_weight = weights[self.weight_name] - t_weight = self._pad_tensor_param(t_weight) - new_hidden_size = t_weight.shape[0] - split_n_embed = new_hidden_size // self.tp_world_size_ - assert new_hidden_size % self.tp_world_size_ == 0 - - start = split_n_embed * self.tp_rank_ - end = split_n_embed * (self.tp_rank_ + 1) - - self.weight = t_weight[start:end].to(self.data_type_).cuda(get_current_device_id()) - - if self.bias_name in weights and self.bias is None: - t_bias = weights[self.bias_name] - t_bias = self._pad_tensor_param(t_bias) - new_hidden_size = t_bias.shape[0] - split_n_embed = new_hidden_size // self.tp_world_size_ - assert new_hidden_size % self.tp_world_size_ == 0 - - start = split_n_embed * self.tp_rank_ - end = split_n_embed * (self.tp_rank_ + 1) - - self.bias = t_bias[start:end].to(self.data_type_).cuda(get_current_device_id()) - - -class TpHeadNormWeight(_NormWeight): - def __init__(self, weight_name, data_type, bias_name=None): - super().__init__(weight_name, data_type, bias_name) - - def load_hf_weights(self, weights): - if self.weight_name in weights and self.weight is None: - t_weight = weights[self.weight_name] - start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) - self.weight: torch.Tensor = ( - t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) - assert self.weight.ndim == 2 - - if self.bias_name in weights and self.bias is None: - t_bias = weights[self.bias_name] - start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_bias) - self.bias: torch.Tensor = ( - t_bias[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) - assert self.bias.ndim == 2 + start = self.norm_dim * self.tp_rank_ + end = self.norm_dim * (self.tp_rank_ + 1) + + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name][start:end].to(self.data_type_)) + self.is_weight_ready = True 
+ if self.bias_name in weights: + self.bias.copy_(weights[self.bias_name][start:end].to(self.data_type_)) + self.is_bias_ready = True diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 97bc762370..1889ceb391 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -4,6 +4,7 @@ from .base_layer_weight import BaseLayerWeight from .meta_weights import BaseWeight, MMWeightTpl from lightllm.utils.log_utils import init_logger +from lightllm.common.quantization import Quantcfg logger = init_logger(__name__) @@ -15,7 +16,7 @@ def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): self.data_type_ = data_type self.network_config_ = network_config self.mode = mode - self.quant_cfg = quant_cfg + self.quant_cfg: Quantcfg = quant_cfg self._parse_config() self._init_weight_names() self._init_weight() @@ -41,3 +42,6 @@ def load_hf_weights(self, weights): attr.load_hf_weights(weights) elif isinstance(attr, BaseWeight): attr.load_hf_weights(weights) + + def get_quant_method(self, name): + return self.quant_cfg.get_quant_method(self.layer_num_, name) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 26f59258cd..ecf2e6d42f 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -6,6 +6,7 @@ from .triton_quant.triton_quant import * from .deepgemm_quant import * from .awq_quant import * +from .no_quant import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -78,4 +79,6 @@ def get_quant_type(self, layer_num, name): def get_quant_method(self, layer_num, name): quant_type = self.get_quant_type(layer_num, name) - return QUANTMETHODS.get(quant_type) + quant_method = QUANTMETHODS.get(quant_type) + quant_method.hf_quantization_config = self.hf_quantization_config + return quant_method diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/awq_quant.py index 8c04cdcea9..d523cce757 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/awq_quant.py @@ -9,8 +9,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from lightllm.utils.dist_utils import get_current_device_id -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack if HAS_VLLM: awq_dequantize = vllm_ops.awq_dequantize @@ -39,16 +38,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -72,21 +72,21 @@ def __init__(self): def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( self, 
input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias NEED_DEQUANT_WEIGHT = input_tensor.shape[:-1].numel() >= 256 if NEED_DEQUANT_WEIGHT: @@ -99,6 +99,33 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + end_idx = start_idx + weight_zero_point.shape[1] + weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) + return + @QUANTMETHODS.register("awq_marlin") class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): @@ -115,20 +142,15 @@ def __init__(self): self.vllm_quant_type = TYPE_MAP[self.nbits] self.has_weight_scale = True self.has_weight_zero_point = True + self.tile_size = 16 @property def method_name(self): return "awq_marlin" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: raise NotImplementedError("AWQ online quantization is not supported yet.") - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return True - def params_repack( self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -144,47 +166,18 @@ def params_repack( ) return weight, weight_scale, weight_zero_point - def _process_weight_after_loading(self, weight: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - self.k = weight.shape[0] - self.n = weight.shape[1] * self.pack_factor - return vllm_ops.awq_marlin_repack( - weight, - size_k=weight.shape[0], - size_n=weight.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - - def _process_weight_scale_after_loading(self, weight_scale: torch.Tensor) -> torch.Tensor: - assert self.hf_quantization_config is not None, "hf_quantization_config is not set" - group_size = self.hf_quantization_config["group_size"] - 
return marlin_permute_scales( - weight_scale, - size_k=weight_scale.shape[0] * group_size, - size_n=weight_scale.shape[1], - group_size=self.hf_quantization_config["group_size"], - ) - - def _process_weight_zero_point_after_loading(self, weight_zero_point: torch.Tensor) -> torch.Tensor: - return awq_to_marlin_zero_points( - weight_zero_point, - size_k=weight_zero_point.shape[0], - size_n=weight_zero_point.shape[1] * self.pack_factor, - num_bits=self.hf_quantization_config["bits"], - ) - def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point - bias = weight_pack.bias reshaped_x = input_tensor.reshape(-1, input_tensor.shape[-1]) use_atomic_add = should_use_atomic_add_reduce( @@ -219,6 +212,62 @@ def apply( out.add_(bias) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + self.n = out_dim + self.k = in_dim + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty( + expert_prefix + (in_dim // self.tile_size, out_dim * self.tile_size // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + device_id = get_current_device_id() + repack_weight = vllm_ops.awq_marlin_repack( + weight.cuda(device_id), + size_k=weight.shape[0], + size_n=weight.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor * self.tile_size + weight_pack.weight[:, start_idx : start_idx + repack_weight.shape[1]].copy_(repack_weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + assert self.hf_quantization_config is not None, "hf_quantization_config is not set" + group_size = self.hf_quantization_config["group_size"] + device_id = get_current_device_id() + repack_weight_scale = marlin_permute_scales( + weight_scale.cuda(device_id), + size_k=weight_scale.shape[0] * group_size, + size_n=weight_scale.shape[1], + group_size=self.hf_quantization_config["group_size"], + ) + weight_pack.weight_scale[:, start_idx : start_idx + repack_weight_scale.shape[1]].copy_(repack_weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + device_id = get_current_device_id() + repack_weight_zero_point = awq_to_marlin_zero_points( + weight_zero_point.cuda(device_id), + size_k=weight_zero_point.shape[0], + size_n=weight_zero_point.shape[1] * self.pack_factor, + num_bits=self.hf_quantization_config["bits"], + ) + start_idx = start_idx // self.pack_factor + weight_pack.weight_zero_point[:, start_idx 
: start_idx + repack_weight_zero_point.shape[1]].copy_( + repack_weight_zero_point + ) + return + # adapted from # https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 7dbd3806b9..86dd9b5729 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -1,5 +1,6 @@ import os import torch +from torch.types import Device from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F @@ -9,8 +10,8 @@ ) from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack + try: HAS_DEEPGEMM = True import deep_gemm @@ -26,17 +27,17 @@ def __init__(self): self.cache_manager = g_cache_manager assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" - def quantize(self, weight: torch.Tensor): - """ """ - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -60,26 +61,30 @@ def __init__(self): def method_name(self): return "deepgemm-fp8w8a8-b128" - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - weight, scale = weight_quant(weight, self.block_size) - return weight, scale, None + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale input_scale = None alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty m, k = input_tensor.shape - n = qweight.shape[1] + n = qweight.shape[0] if input_scale is None: qinput_tensor, input_scale = per_token_group_quant_fp8( input_tensor, @@ -92,9 +97,35 @@ def apply( if out is None: out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight.t(), weight_scale.t()), out) + _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + 
weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return + def _deepgemm_fp8_nt(a_tuple, b_tuple, out): if HAS_DEEPGEMM: diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py new file mode 100644 index 0000000000..f342607c10 --- /dev/null +++ b/lightllm/common/quantization/no_quant.py @@ -0,0 +1,52 @@ +from .quantize_method import QuantizationMethod, WeightPack +from .registry import QUANTMETHODS +import torch +from typing import Optional + + +@QUANTMETHODS.register("none") +class NoQuantization(QuantizationMethod): + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + weight = weight_pack.weight.t() + if out is None: + shape = (input_tensor.shape[0], weight.shape[1]) + dtype = input_tensor.dtype + device = input_tensor.device + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + else: + out = torch.empty(shape, dtype=dtype, device=device) + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype).cuda(device_id) + return WeightPack(weight=weight, weight_scale=None, weight_zero_point=None) + + def weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + return + + @property + def method_name(self): + return "none" + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0], :].copy_(weight) + return diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 9b629bcaf1..77e59465ee 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,38 +1,58 @@ import torch from abc import ABC, abstractmethod +from dataclasses import dataclass from lightllm.utils.dist_utils import get_current_device_id -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple -if 
TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +@dataclass +class WeightPack: + weight: Optional[torch.Tensor] = None + weight_scale: Optional[torch.Tensor] = None + weight_zero_point: Optional[torch.Tensor] = None + + def get_expert(self, expert_idx: int): + assert self.weight.ndim == 3, f"weight must be a 3D tensor, but got {self.weight.ndim}" + weight = self.weight[expert_idx] + weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None + weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) class QuantizationMethod(ABC): def __init__(self): super().__init__() self.device_id_ = get_current_device_id() - self.weight_suffix = None + self.weight_suffix = "weight" self.weight_scale_suffix = None self.weight_zero_point_suffix = None self.act_scale_suffix = None self.has_weight_scale: bool = None self.has_weight_zero_point: bool = None + self.group_size: int = -1 # -1表示不分组即per-channel量化,其他表示分组大小 + self.pack_factor: int = 1 + # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None @abstractmethod - def quantize(self, weights: torch.Tensor): + def quantize( + self, + weight: torch.Tensor, + output: WeightPack, + offset: int = 0, + ) -> None: pass @abstractmethod def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", - bias: Optional[torch.Tensor] = None, + weight_pack: "WeightPack", out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: pass @@ -41,20 +61,26 @@ def apply( def method_name(self): pass + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + pass + def weight_need_quanted(self, weight: torch.Tensor) -> bool: # 判断一个 weight 是否需要进行量化操作。 return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] - def params_need_repack(self) -> bool: - """ - 用于说明是否需要对量化后的权重进行repack操作,目前只有awq支持 - """ - return False - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - return weight, weight_scale, weight_zero_point + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight" + ) + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight scale" + ) + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + raise NotImplementedError( + f"quantization method {self.method_name} is not supported to load offline quantized weight zero point" + ) diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index 674a22b60f..e9b4073987 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -1,5 +1,4 @@ from .quantize_method import 
QuantizationMethod -from typing import Type class QuantMethodFactory: @@ -17,9 +16,7 @@ def decorator(cls): return decorator - def get(self, key, *args, **kwargs) -> Type[QuantizationMethod]: - if key == "none": - return None + def get(self, key, *args, **kwargs) -> "QuantizationMethod": quant_method_class = self._quant_methods.get(key) if not quant_method_class: raise ValueError(f"QuantMethod '{key}' not supported.") diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py index ba4115b1d9..d1db65b35a 100644 --- a/lightllm/common/quantization/torchao_quant.py +++ b/lightllm/common/quantization/torchao_quant.py @@ -5,8 +5,7 @@ import torch.nn.functional as F from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from .quantize_method import WeightPack try: HAS_TORCH_AO = True @@ -34,17 +33,17 @@ def __init__(self): assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" self.quant_func = None - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: """ """ dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) quantize_(dummy_linear, self.quant_func) - return dummy_linear.weight, None, None + return WeightPack(weight=dummy_linear.weight, weight_scale=None, weight_zero_point=None) def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py index 410f925a5e..9f6a7bee27 100644 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ b/lightllm/common/quantization/triton_quant/triton_quant.py @@ -7,8 +7,7 @@ from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 from typing import TYPE_CHECKING, Optional -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack +from lightllm.common.quantization.quantize_method import WeightPack class TritonBaseQuantizationMethod(QuantizationMethod): @@ -18,16 +17,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> WeightPack: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -44,17 +44,18 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: # TODO block-wise quant kernel - pass + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: 
Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -83,3 +84,29 @@ def apply( dtype=input_tensor.dtype, ) return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py index 31004de4e3..1728e799db 100644 --- a/lightllm/common/quantization/w8a8_quant.py +++ b/lightllm/common/quantization/w8a8_quant.py @@ -11,8 +11,8 @@ from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -if TYPE_CHECKING: - from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightPack + +from .quantize_method import WeightPack if HAS_LIGHTLLM_KERNEL: @@ -38,16 +38,17 @@ def __init__(self): self.cache_manager = g_cache_manager - def quantize(self, weight: torch.Tensor): - pass + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + raise NotImplementedError("Not implemented") def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -55,6 +56,11 @@ def apply( def method_name(self): return "w8a8-base" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + raise NotImplementedError("Not implemented") + @QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) class w8a8QuantizationMethod(BaseQuantizationMethod): @@ -63,27 +69,27 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): - if isinstance(weight, tuple): - return (weight[0].transpose(0, 1).cuda(self.device_id_),) + weight[1:] - weight = weight.float() + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 - weight = weight.transpose(0, 1) / scale.reshape(1, -1) + weight = weight / scale.reshape(-1, 1) weight = torch.round(weight.clamp(min=-128, 
max=127)).to(dtype=torch.int8) - return weight.cuda(self.device_id_), scale.cuda(self.device_id_), None + output.weight[offset : offset + weight.shape[0]].copy_(weight) + output.weight_scale[offset : offset + weight.shape[0]].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_scale = None - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias input_scale = None # dynamic quantization for input tensor x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] @@ -100,6 +106,14 @@ def apply( def method_name(self): return "vllm-w8a8" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) class FP8w8a8QuantizationMethod(BaseQuantizationMethod): @@ -109,19 +123,20 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: if self.is_moe: - return self.quantize_moe(weight) + return self.quantize_moe(weight, output, offset) qweight, weight_scale = scaled_fp8_quant( - weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) - return qweight.transpose(0, 1), weight_scale, None + output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) + output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) + return - def quantize_moe(self, weight: torch.Tensor): + def quantize_moe(self, weight: torch.Tensor) -> WeightPack: num_experts = weight.shape[0] - qweights = [] - weight_scales = [] qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) + weight_scales = [] for i in range(num_experts): qweight, weight_scale = scaled_fp8_quant( weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True @@ -129,19 +144,19 @@ def quantize_moe(self, weight: torch.Tensor): qweights[i] = qweight weight_scales.append(weight_scale) weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return qweights, weight_scale, None + return WeightPack(weight=qweights, weight_scale=weight_scale) def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight + qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - bias = weight_pack.bias x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) m = input_tensor.shape[0] n = 
qweight.shape[1] @@ -160,6 +175,14 @@ def apply( def method_name(self): return "vllm-fp8w8a8" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): @@ -170,21 +193,26 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - def quantize(self, weight: torch.Tensor): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - raise Exception("Not implemented") + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return def apply( self, input_tensor: torch.Tensor, - weight_pack: "MMWeightPack", + weight_pack: WeightPack, out: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - bias = weight_pack.bias + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor m, k = input_tensor.shape n = qweight.shape[1] @@ -213,3 +241,13 @@ def apply( @property def method_name(self): return "vllm-fp8w8a8-b128" + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index f6a841b1aa..50a4ca437e 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -2,7 +2,9 @@ import torch import numpy as np -from lightllm.common.basemodel.layer_weights.meta_weights.gpt_oss_fused_moe_weight_tp import GPTOSSFusedMoeWeightTP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gpt_oss_fused_moe_weight_tp import ( + GPTOSSFusedMoeWeightTP, +) from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight from lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 
6b92272ee7..8b496bfa2a 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -24,11 +24,16 @@ def _init_weight(self): self._init_norm() def _parse_config(self): + self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] + self.head_dim = self.network_config_.get("head_dim", head_dim) + assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] self.n_inter = self.network_config_["intermediate_size"] - self.n_kv_head = self.network_config_["num_key_value_heads"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) + self.n_head = self.network_config_["num_attention_heads"] def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" @@ -57,49 +62,57 @@ def _init_weight_names(self): self._ffn_norm_bias_name = None def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.tp_q_head_num_ * self.head_dim + k_out_dim = self.tp_k_head_num_ * self.head_dim + v_out_dim = self.tp_v_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) self.kv_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[k_out_dim, v_out_dim], weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_o(self): + in_dim = self.tp_o_head_num_ * self.head_dim + out_dim = self.n_embed self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): + in_dim = self.n_embed + out_dim = self.n_inter // self.tp_world_size_ self.gate_up_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[out_dim, out_dim], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=out_dim, + out_dims=[in_dim], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index f425ad08ba..4967687103 100644 --- 
a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -62,6 +62,7 @@ def _init_moe(self): layer_num=self.layer_num_, quant_cfg=self.quant_cfg, num_fused_shared_experts=0, + hidden_size=self.network_config_.get("hidden_size"), ) else: raise ValueError(f"Unsupported moe mode: {moe_mode}") diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index a8a57c02ed..ddce6a4bdf 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -4,4 +4,5 @@ class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 6962818c49..4aed0c9de8 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -11,15 +11,6 @@ def _init_weight_names(self): self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" - def _parse_config(self): - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) - self.tp_v_head_num_ = self.tp_k_head_num_ - self.tp_o_head_num_ = self.tp_q_head_num_ - head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] - self.head_dim = self.network_config_.get("head_dim", head_dim) - assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 - def _repeat_weight(self, name, weights): # for tp_world_size_ > num_key_value_heads if name not in weights: diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6f..bc4b548192 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -53,10 +53,11 @@ def _init_weight(self): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.network_config_["hidden_size"], + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=None, tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index d5bdd79a7b..b9e5270890 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -10,7 +10,9 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config, mode): super().__init__(data_type, network_config, mode) + self._create_weight() + def _create_weight(self): self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", 
data_type=self.data_type_, diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 28a26cb4b3..b61d66905b 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,3 +1,4 @@ +import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0b..562453d467 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -13,6 +13,38 @@ def __init__(self, data_type, network_config, mode): self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] self.llm_hidden_size = self.network_config_["llm_hidden_size"] + self._create_weight() + return + + def _create_weight(self): + split_indexes = np.linspace(0, self.embed_dim, self.tp_world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + split_embed_dim = split_end - split_start + + # Pre-allocate memory for vision model weights + self.class_embedding = torch.empty((1, 1, split_embed_dim), dtype=self.data_type_).cuda() + self.position_embedding = torch.empty( + (1, 197, split_embed_dim), dtype=self.data_type_ + ).cuda() # 197 = (224//16)^2 + 1 + self.patch_embedding_weight_ = torch.empty( + (split_embed_dim, 3, self.patch_size, self.patch_size), dtype=self.data_type_ + ).cuda() + self.patch_embedding_bias_ = torch.empty(split_embed_dim, dtype=self.data_type_).cuda() + + # Pre-allocate memory for adapter weights + self.layernorm_weight_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + self.layernorm_bias_ = torch.empty(self.embed_dim, dtype=self.data_type_).cuda() + + split_indexes_llm = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) + split_start_llm = split_indexes_llm[self.tp_rank_] + split_end_llm = split_indexes_llm[self.tp_rank_ + 1] + split_llm_hidden_size = split_end_llm - split_start_llm + + self.mlp1_1_weight_ = torch.empty((self.llm_hidden_size, split_llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_1_bias_ = torch.empty(split_llm_hidden_size, dtype=self.data_type_).cuda() + self.mlp1_3_weight_ = torch.empty((split_llm_hidden_size, self.llm_hidden_size), dtype=self.data_type_).cuda() + self.mlp1_3_bias_ = torch.empty(self.llm_hidden_size, dtype=self.data_type_).cuda() return def _cuda(self, cpu_tensor): @@ -40,40 +72,38 @@ def load_hf_weights(self, weights): split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "vision_model.embeddings.class_embedding" in weights: - self.class_embedding = self._cuda( - weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end] - ) + self.class_embedding.copy_(weights["vision_model.embeddings.class_embedding"][:, :, split_start:split_end]) if "vision_model.embeddings.position_embedding" in weights: - self.position_embedding = self._cuda( + self.position_embedding.copy_( weights["vision_model.embeddings.position_embedding"][:, :, split_start:split_end] ) if "vision_model.embeddings.patch_embedding.weight" 
in weights: - self.patch_embedding_weight_ = self._cuda( + self.patch_embedding_weight_.copy_( weights["vision_model.embeddings.patch_embedding.weight"][split_start:split_end, :, :, :] ) if "vision_model.embeddings.patch_embedding.bias" in weights: - self.patch_embedding_bias_ = self._cuda( + self.patch_embedding_bias_.copy_( weights["vision_model.embeddings.patch_embedding.bias"][split_start:split_end] ) if "mlp1.0.weight" in weights: - self.layernorm_weight_ = self._cuda(weights["mlp1.0.weight"]) + self.layernorm_weight_.copy_(weights["mlp1.0.weight"]) if "mlp1.0.bias" in weights: - self.layernorm_bias_ = self._cuda(weights["mlp1.0.bias"]) + self.layernorm_bias_.copy_(weights["mlp1.0.bias"]) split_indexes = np.linspace(0, self.llm_hidden_size, self.tp_world_size_ + 1, dtype=np.int64) split_start = split_indexes[self.tp_rank_] split_end = split_indexes[self.tp_rank_ + 1] if "mlp1.1.weight" in weights: - self.mlp1_1_weight_ = self._cuda(weights["mlp1.1.weight"][split_start:split_end, :]).t() + self.mlp1_1_weight_.copy_(weights["mlp1.1.weight"][split_start:split_end, :].t()) if "mlp1.1.bias" in weights: - self.mlp1_1_bias_ = self._cuda(weights["mlp1.1.bias"][split_start:split_end]) + self.mlp1_1_bias_.copy_(weights["mlp1.1.bias"][split_start:split_end]) if "mlp1.3.weight" in weights: - self.mlp1_3_weight_ = self._cuda(weights["mlp1.3.weight"][:, split_start:split_end]).t() + self.mlp1_3_weight_.copy_(weights["mlp1.3.weight"][:, split_start:split_end].t()) if "mlp1.3.bias" in weights: - self.mlp1_3_bias_ = self._cuda(weights["mlp1.3.bias"]) + self.mlp1_3_bias_.copy_(weights["mlp1.3.bias"]) return diff --git a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py index 811d39a729..e3a71379d2 100644 --- a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py @@ -8,10 +8,10 @@ import json from typing import List from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep_redundancy import ( +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep_redundancy import ( FusedMoeWeightEPAutoRedundancy, ) -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_ep import FusedMoeWeightEP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep import FusedMoeWeightEP from lightllm.utils.envs_utils import get_env_start_args, get_redundancy_expert_update_interval from lightllm.utils.envs_utils import get_redundancy_expert_update_max_load_count from lightllm.utils.envs_utils import get_redundancy_expert_num From 04b214b5d879a2a2973696d12e6831aa2536c708 Mon Sep 17 00:00:00 2001 From: sangchengmeng <101796078+SangChengC@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:35:04 +0800 Subject: [PATCH 03/43] fix-internvl (#1171) --- .../models/qwen3_vl/layer_infer/transformer_layer_infer.py | 1 + lightllm/models/vit/model.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 175340a77b..17ce4b7693 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -24,6 +24,7 @@ class 
Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): def __init__(self, layer_num, network_config, mode=[]): super().__init__(layer_num, network_config, mode) + self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" ) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdfe..b8e6eaf929 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -185,7 +185,7 @@ def encode(self, images: List[ImageItem]): else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - cur_num = img_tensors[-1].shape[0] + cur_num = img.token_num valid_ids.append([valid_id, valid_id + cur_num]) valid_id += cur_num @@ -195,7 +195,7 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) pixel_values = imgs.cuda().to(dtype=self.data_type) all_img_embeds = self.forward(pixel_values) - return all_img_embeds, uuids, valid_ids + return all_img_embeds.view(-1, all_img_embeds.shape[-1]), uuids, valid_ids def cuda(self): return self From b4858226296a130c8392d3226356679dfd8041be Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Fri, 9 Jan 2026 23:35:58 +0800 Subject: [PATCH 04/43] add att backend (#1165) Co-authored-by: wangzaijun Co-authored-by: root --- .pre-commit-config.yaml | 2 +- .../CN/source/tutorial/api_server_args_zh.rst | 28 +- .../source/tutorial/deepseek_deployment.rst | 36 +- .../tutorial/multi_level_cache_deployment.rst | 8 +- docs/CN/source/tutorial/reasoning_parser.rst | 3 +- .../EN/source/tutorial/api_server_args_zh.rst | 29 - .../source/tutorial/deepseek_deployment.rst | 33 +- .../tutorial/multi_level_cache_deployment.rst | 8 +- docs/EN/source/tutorial/reasoning_parser.rst | 3 +- .../common/basemodel/attention/__init__.py | 18 + .../common/basemodel/attention/base_att.py | 117 +++ .../basemodel/attention/create_utils.py | 80 ++ .../basemodel/attention/fa3}/__init__.py | 0 lightllm/common/basemodel/attention/fa3/fp.py | 243 ++++++ .../common/basemodel/attention/fa3/fp8.py | 221 +++++ .../common/basemodel/attention/fa3/mla.py | 229 +++++ .../attention/flashinfer}/__init__.py | 0 .../basemodel/attention/flashinfer/fp.py | 229 +++++ .../basemodel/attention/flashinfer/fp8.py | 121 +++ .../basemodel/attention/flashinfer/mla.py | 233 ++++++ .../basemodel/attention/triton}/__init__.py | 0 .../common/basemodel/attention/triton/fp.py | 275 ++++++ .../basemodel/attention/triton/int4kv.py | 170 ++++ .../basemodel/attention/triton/int8kv.py | 196 +++++ .../common/basemodel/attention/triton/mla.py | 125 +++ lightllm/common/basemodel/basemodel.py | 73 +- lightllm/common/basemodel/batch_objs.py | 5 +- lightllm/common/basemodel/cuda_graph.py | 28 +- lightllm/common/basemodel/infer_struct.py | 40 +- .../basemodel/layer_infer/post_layer_infer.py | 3 +- .../basemodel/layer_infer/pre_layer_infer.py | 3 +- .../template/post_layer_infer_template.py | 4 +- .../template/pre_layer_infer_template.py | 4 +- ...transformer_layer_infer_cohere_template.py | 136 --- .../transformer_layer_infer_template.py | 16 +- .../layer_infer/transformer_layer_infer.py | 3 +- .../pre_and_post_layer_weight.py | 3 +- .../layer_weights/transformer_layer_weight.py | 3 +- .../common/basemodel/prefill_cuda_graph.py | 2 - .../triton_kernel/alibi_att}/__init__.py | 0 .../context_flashattention_nopad.py | 0 .../alibi_att}/token_attention_nopad_att1.py | 0 .../token_attention_nopad_reduceV.py | 0 
.../token_attention_nopad_softmax.py | 41 +- .../alibi_att}/token_flashattention_nopad.py | 0 .../basemodel/triton_kernel/att}/__init__.py | 0 .../triton_kernel/att/decode_att}/__init__.py | 0 .../att/decode_att/gqa}/__init__.py | 0 .../gqa/flash_decoding}/__init__.py | 0 .../gqa/flash_decoding}/gqa_flash_decoding.py | 11 +- .../gqa_flash_decoding_stage1.py | 12 +- .../gqa_flash_decoding_stage2.py | 50 +- .../flash_decoding}/gqa_flash_decoding_vsm.py | 2 +- .../gqa}/gqa_decode_flashattention_nopad.py | 0 .../att/decode_att/int4kv}/__init__.py | 0 .../int4kv/int4kv_flash_decoding_stage1.py | 200 +++++ .../int4kv/ppl_int4kv_flash_decoding.py | 50 ++ .../att/decode_att/int8kv}/__init__.py | 0 .../int8kv}/ppl_int8kv_flash_decoding.py | 13 +- .../ppl_int8kv_flash_decoding_diverse.py | 15 +- ...pl_int8kv_flash_decoding_diverse_stage1.py | 3 +- ...pl_int8kv_flash_decoding_diverse_stage3.py | 0 .../att/decode_att/mha/__init__.py | 0 .../decode_att/mha/flash_decoding/__init__.py | 0 .../mha/flash_decoding}/flash_decoding.py | 13 +- .../flash_decoding}/flash_decoding_stage1.py | 22 +- .../flash_decoding}/flash_decoding_stage2.py | 23 +- .../mha/stage3_decode_att/__init__.py | 0 .../token_attention_nopad_att1.py | 42 +- .../token_attention_nopad_reduceV.py | 48 +- .../token_attention_nopad_softmax.py | 39 +- .../token_attention_softmax_and_reducev.py | 0 .../att/decode_att/ppl_fp16/__init__.py | 0 .../ppl_fp16}/ppl_fp16_flash_decoding.py | 20 +- .../triton_kernel/att/prefill_att/__init__.py | 0 .../context_flashattention_nopad.py | 151 +--- .../triton_kernel/destindex_copy_kv.py | 130 +-- .../triton_kernel/gen_decode_params.py | 6 +- .../triton_kernel/gen_prefill_params.py | 1 + .../triton_kernel/kv_copy/__init__.py | 0 .../triton_kernel/kv_copy/mla_copy_kv.py} | 11 +- .../kv_copy/ppl_int4kv_copy_kv.py | 374 +++++++++ .../kv_copy/ppl_int8kv_copy_kv.py | 330 ++++++++ .../triton_kernel/mla_att/__init__.py | 0 .../mla_att/decode_att/__init__.py | 1 + .../mla_att/decode_att}/gqa_flash_decoding.py | 23 +- .../decode_att}/gqa_flash_decoding_config.py | 0 .../decode_att}/gqa_flash_decoding_stage1.py | 0 .../decode_att}/gqa_flash_decoding_stage2.py | 0 .../mla_att/prefill_att/__init__.py | 1 + .../context_flashattention_nopad_with_v.py | 0 .../triton_kernel/repack_kv_index.py | 0 .../common/kv_cache_mem_manager/__init__.py | 4 - .../deepseek2_fp8kv_mem_manager.py | 8 - .../deepseek2_mem_manager.py | 26 +- .../export_calibration_mem_manager.py | 22 + .../int8kv_mem_manager.py | 29 - .../kv_cache_mem_manager/mem_manager.py | 17 +- .../common/kv_cache_mem_manager/mem_utils.py | 35 +- .../offline_fp8_quant_mem_manager.py | 21 +- .../ppl_int4kv_mem_manager.py | 24 +- .../ppl_int8kv_mem_manager.py | 24 +- lightllm/models/__init__.py | 2 - .../bloom/layer_infer/post_layer_infer.py | 4 +- .../bloom/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 82 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/bloom/model.py | 6 + .../layer_infer/transformer_layer_infer.py | 28 - .../pre_and_post_layer_weight.py | 20 - .../layer_weights/transformer_layer_weight.py | 72 -- lightllm/models/chatglm2/model.py | 78 -- .../chatglm2/triton_kernel/rotary_emb.py | 160 ---- lightllm/models/cohere/infer_struct.py | 8 - .../cohere/layer_infer/post_layer_infer.py | 71 -- .../layer_infer/transformer_layer_infer.py | 84 -- .../pre_and_post_layer_weight.py | 25 - .../layer_weights/transformer_layer_weight.py | 25 - lightllm/models/cohere/model.py | 
69 -- .../models/cohere/triton_kernels/layernorm.py | 131 --- .../cohere/triton_kernels/rotary_emb.py | 199 ----- .../deepseek2/flashattention_infer_struct.py | 65 -- .../models/deepseek2/flashinfer_struct.py | 106 --- lightllm/models/deepseek2/infer_struct.py | 15 - .../layer_infer/transformer_layer_infer.py | 565 ++----------- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/deepseek2/model.py | 50 +- .../triton_kernel/gqa_flash_decoding_fp8.py | 9 +- .../deepseek2/triton_kernel/sample_kv.py | 154 ++-- .../layer_infer/pre_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../gemma3/layer_infer/post_layer_infer.py | 4 +- .../gemma3/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 10 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 3 +- lightllm/models/gemma3/model.py | 4 - .../gemma_2b/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_infer/transformer_layer_infer.py | 131 ++- .../layer_weights/transformer_layer_weight.py | 3 +- lightllm/models/gpt_oss/model.py | 7 +- .../layer_weights/transformer_layer_weight.py | 8 +- lightllm/models/internlm/model.py | 3 - .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../pre_and_post_layer_weight.py | 12 +- .../llama/flashattention_infer_struct.py | 106 --- lightllm/models/llama/flashinfer_struct.py | 127 --- lightllm/models/llama/infer_struct.py | 1 - .../llama/layer_infer/post_layer_infer.py | 4 +- .../llama/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 782 +----------------- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 3 +- lightllm/models/llama/model.py | 34 - .../llama/triton_kernel/flash_decoding.py | 37 - .../triton_kernel/flash_decoding_stage1.py | 106 --- .../triton_kernel/flash_decoding_stage2.py | 64 -- .../llama/triton_kernel/ppl_int4kv_copy_kv.py | 138 ---- .../ppl_int4kv_flash_decoding.py | 50 -- .../llama/triton_kernel/ppl_quant_copy_kv.py | 294 ------- .../token_attention_nopad_reduceV.py | 223 ----- .../pre_and_post_layer_weight.py | 6 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- lightllm/models/mistral/model.py | 7 - .../context_flashattention_nopad.py | 228 ----- .../init_att_sliding_window_info.py | 45 - .../token_attention_softmax_and_reducev.py | 132 --- .../layer_infer/post_layer_infer.py | 4 +- .../layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_infer/transformer_layer_infer.py | 5 +- .../layer_weights/transformer_layer_weight.py | 3 +- .../layer_infer/transformer_layer_infer.py | 57 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../context_flashattention_nopad.py | 433 ---------- .../phi3/triton_kernel/destindex_copy_kv.py | 192 ----- .../layer_infer/transformer_layer_infer.py | 7 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- lightllm/models/qwen2_vl/infer_struct.py | 9 +- 
.../layer_infer/transformer_layer_infer.py | 4 +- lightllm/models/qwen2_vl/model.py | 3 - .../layer_infer/transformer_layer_infer.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/qwen3/model.py | 2 - .../layer_infer/transformer_layer_infer.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../qwen3_vl/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 12 +- .../pre_and_post_layer_weight.py | 4 +- lightllm/models/qwen3_vl/model.py | 3 - .../layer_infer/transformer_layer_infer.py | 5 +- .../pre_and_post_layer_weight.py | 4 +- .../transformers_layer_weight.py | 4 +- lightllm/models/qwen3_vl_moe/model.py | 3 - .../qwen_vl/layer_infer/pre_layer_infer.py | 4 +- lightllm/models/qwen_vl/model.py | 2 - .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/stablelm/model.py | 3 - .../starcoder/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 8 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- .../pre_and_post_layer_weight.py | 8 +- .../vit/layer_infer/post_layer_infer.py | 3 +- .../models/vit/layer_infer/pre_layer_infer.py | 3 +- .../layer_infer/transformer_layer_infer.py | 3 +- .../pre_and_post_layer_weight.py | 4 +- .../layer_weights/transformer_layer_weight.py | 4 +- lightllm/models/vit/model.py | 12 +- lightllm/server/api_cli.py | 64 +- lightllm/server/api_start.py | 15 - lightllm/server/core/objs/start_args_type.py | 14 +- lightllm/server/router/manager.py | 2 - .../model_infer/mode_backend/base_backend.py | 3 - .../mode_backend/chunked_prefill/impl.py | 2 +- .../mode_backend/dp_backend/impl.py | 6 +- .../generic_padded_pre_process.py | 12 +- .../mode_backend/generic_pre_process.py | 9 +- lightllm/utils/envs_utils.py | 10 +- lightllm/utils/kv_cache_utils.py | 13 - test/acc/test_deepseekr1.sh | 2 +- test/acc/test_deepseekr1_mtp.sh | 2 +- test/acc/test_deepseekr1_mtp_ep.sh | 2 +- test/acc/test_qwen2.sh | 2 +- test/acc/test_qwen3.sh | 2 +- .../benchmark/static_inference/model_infer.py | 10 - .../static_inference/model_infer_mtp.py | 2 - test/start_scripts/README.md | 1 - test/start_scripts/draft.sh | 4 +- test/start_scripts/multi_node_ep_node0.sh | 2 +- test/start_scripts/multi_node_ep_node1.sh | 2 +- test/start_scripts/multi_node_tp_node0.sh | 2 +- test/start_scripts/multi_node_tp_node1.sh | 2 +- .../multi_pd_master/pd_decode.sh | 2 +- .../multi_pd_master/pd_prefill.sh | 2 +- test/start_scripts/single_node_ep.sh | 2 +- test/start_scripts/single_node_tp.sh | 2 +- .../single_node_tp_cpu_cache_enable.sh | 2 +- .../single_pd_master/pd_decode.sh | 2 +- .../single_pd_master/pd_nixl_decode.sh | 2 +- .../single_pd_master/pd_nixl_prefill.sh | 2 +- .../single_pd_master/pd_prefill.sh | 2 +- test/test_api/test_generate_api.py | 2 +- .../test_context_flashattention_nopad.py | 104 +++ .../kv_copy/test_ppl_int4kv_copy_kv.py | 62 ++ .../kv_copy/test_ppl_int8kv_copy_kv.py | 86 ++ .../triton_kernel/test_gen_decode_params.py | 12 - .../server/core/objs/test_shm_req_manager.py | 2 - 270 files changed, 4552 insertions(+), 6087 deletions(-) create mode 100644 
lightllm/common/basemodel/attention/__init__.py create mode 100644 lightllm/common/basemodel/attention/base_att.py create mode 100644 lightllm/common/basemodel/attention/create_utils.py rename lightllm/{models/bloom/triton_kernel => common/basemodel/attention/fa3}/__init__.py (100%) create mode 100644 lightllm/common/basemodel/attention/fa3/fp.py create mode 100644 lightllm/common/basemodel/attention/fa3/fp8.py create mode 100644 lightllm/common/basemodel/attention/fa3/mla.py rename lightllm/{models/chatglm2 => common/basemodel/attention/flashinfer}/__init__.py (100%) create mode 100644 lightllm/common/basemodel/attention/flashinfer/fp.py create mode 100644 lightllm/common/basemodel/attention/flashinfer/fp8.py create mode 100644 lightllm/common/basemodel/attention/flashinfer/mla.py rename lightllm/{models/chatglm2/layer_infer => common/basemodel/attention/triton}/__init__.py (100%) create mode 100644 lightllm/common/basemodel/attention/triton/fp.py create mode 100644 lightllm/common/basemodel/attention/triton/int4kv.py create mode 100644 lightllm/common/basemodel/attention/triton/int8kv.py create mode 100644 lightllm/common/basemodel/attention/triton/mla.py delete mode 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py rename lightllm/{models/chatglm2/layer_weights => common/basemodel/triton_kernel/alibi_att}/__init__.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/context_flashattention_nopad.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_att1.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_reduceV.py (100%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_attention_nopad_softmax.py (77%) rename lightllm/{models/bloom/triton_kernel => common/basemodel/triton_kernel/alibi_att}/token_flashattention_nopad.py (100%) rename lightllm/{models/chatglm2/triton_kernel => common/basemodel/triton_kernel/att}/__init__.py (100%) rename lightllm/{models/cohere => common/basemodel/triton_kernel/att/decode_att}/__init__.py (100%) rename lightllm/{models/cohere/layer_infer => common/basemodel/triton_kernel/att/decode_att/gqa}/__init__.py (100%) rename lightllm/{models/cohere/layer_weights => common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding}/__init__.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding}/gqa_flash_decoding.py (65%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding}/gqa_flash_decoding_stage1.py (96%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding}/gqa_flash_decoding_stage2.py (65%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding}/gqa_flash_decoding_vsm.py (99%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/gqa}/gqa_decode_flashattention_nopad.py (100%) rename lightllm/{models/cohere/triton_kernels => common/basemodel/triton_kernel/att/decode_att/int4kv}/__init__.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py 
rename lightllm/{models/mistral/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/__init__.py (100%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/ppl_int8kv_flash_decoding.py (71%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/ppl_int8kv_flash_decoding_diverse.py (88%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/ppl_int8kv_flash_decoding_diverse_stage1.py (98%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/int8kv}/ppl_int8kv_flash_decoding_diverse_stage3.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py rename lightllm/{models/phi3/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding}/flash_decoding.py (63%) rename lightllm/{models/phi3/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding}/flash_decoding_stage1.py (88%) rename lightllm/{models/phi3/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding}/flash_decoding_stage2.py (81%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py rename lightllm/{models/mistral/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att}/token_attention_nopad_att1.py (63%) rename lightllm/{models/mistral/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att}/token_attention_nopad_reduceV.py (59%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att}/token_attention_nopad_softmax.py (71%) rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att}/token_attention_softmax_and_reducev.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/decode_att/ppl_fp16}/ppl_fp16_flash_decoding.py (51%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py rename lightllm/{models/llama/triton_kernel => common/basemodel/triton_kernel/att/prefill_att}/context_flashattention_nopad.py (81%) create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py rename lightllm/{models/deepseek2/triton_kernel/destindex_copy_kv.py => common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py} (92%) create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py create mode 100644 lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py rename lightllm/{models/deepseek2/triton_kernel => common/basemodel/triton_kernel/mla_att/decode_att}/gqa_flash_decoding.py (92%) rename lightllm/{models/deepseek2/triton_kernel => common/basemodel/triton_kernel/mla_att/decode_att}/gqa_flash_decoding_config.py (100%) rename lightllm/{models/deepseek2/triton_kernel => common/basemodel/triton_kernel/mla_att/decode_att}/gqa_flash_decoding_stage1.py (100%) rename lightllm/{models/deepseek2/triton_kernel => 
common/basemodel/triton_kernel/mla_att/decode_att}/gqa_flash_decoding_stage2.py (100%) create mode 100644 lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py rename lightllm/{models/deepseek2/triton_kernel => common/basemodel/triton_kernel/mla_att/prefill_att}/context_flashattention_nopad_with_v.py (100%) rename lightllm/{models/deepseek2 => common/basemodel}/triton_kernel/repack_kv_index.py (100%) delete mode 100644 lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py delete mode 100755 lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py delete mode 100755 lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py delete mode 100755 lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/chatglm2/model.py delete mode 100755 lightllm/models/chatglm2/triton_kernel/rotary_emb.py delete mode 100644 lightllm/models/cohere/infer_struct.py delete mode 100644 lightllm/models/cohere/layer_infer/post_layer_infer.py delete mode 100644 lightllm/models/cohere/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py delete mode 100644 lightllm/models/cohere/layer_weights/transformer_layer_weight.py delete mode 100644 lightllm/models/cohere/model.py delete mode 100644 lightllm/models/cohere/triton_kernels/layernorm.py delete mode 100644 lightllm/models/cohere/triton_kernels/rotary_emb.py delete mode 100644 lightllm/models/deepseek2/flashattention_infer_struct.py delete mode 100644 lightllm/models/deepseek2/flashinfer_struct.py delete mode 100644 lightllm/models/llama/flashattention_infer_struct.py delete mode 100644 lightllm/models/llama/flashinfer_struct.py mode change 100755 => 100644 lightllm/models/llama/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/llama/triton_kernel/flash_decoding.py delete mode 100644 lightllm/models/llama/triton_kernel/flash_decoding_stage1.py delete mode 100644 lightllm/models/llama/triton_kernel/flash_decoding_stage2.py delete mode 100644 lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py delete mode 100644 lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py delete mode 100644 lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py delete mode 100644 lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py delete mode 100644 lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py delete mode 100644 lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py delete mode 100644 lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py delete mode 100644 lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py delete mode 100644 lightllm/models/phi3/triton_kernel/destindex_copy_kv.py create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py create mode 100644 unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py create mode 100644 unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 573ff399c5..e7e043a1f7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,4 +10,4 @@ repos: rev: 6.1.0 hooks: - id: flake8 - args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, 
W292, W293, W503, W606, E231'] + args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231, F541'] diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index ce7a79ab97..5976fcb322 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -183,22 +183,6 @@ PD 分离模式参数 设置为 True 时,--nccl_host 必须等于 config_server_host,--nccl_port 对于 config_server 必须是唯一的, 不要为不同的推理节点使用相同的 nccl_port,这将是严重错误 -attention类型选择参数 ---------------------- - -.. option:: --mode - - 模型推理模式,可以指定多个值: - - * ``triton_int8kv``: 使用 int8 存储 kv cache,可增加 token 容量,使用 triton kernel - * ``ppl_int8kv``: 使用 int8 存储 kv cache,使用 ppl 快速 kernel - * ``ppl_fp16``: 使用 ppl 快速 fp16 解码注意力 kernel - * ``triton_flashdecoding``: 用于长上下文的 flashdecoding 模式,当前支持 llama llama2 qwen - * ``triton_gqa_attention``: 使用 GQA 的模型的快速 kernel - * ``triton_gqa_flashdecoding``: 使用 GQA 的模型的快速 flashdecoding kernel - * ``triton_fp8kv``: 使用 float8 存储 kv cache,目前仅用于 deepseek2 - - 需要阅读源代码以确认所有模型支持的具体模式 调度参数 -------- @@ -327,17 +311,9 @@ attention类型选择参数 推理后端将为解码使用微批次重叠模式 -.. option:: --enable_flashinfer_prefill - - 推理后端将为预填充使用 flashinfer 的注意力 kernel - -.. option:: --enable_flashinfer_decode - - 推理后端将为解码使用 flashinfer 的注意力 kernel - -.. option:: --enable_fa3 +.. option:: --llm_kv_type - 推理后端将为预填充和解码使用 fa3 注意力 kernel + 推理后端使用什么类型的数据存储kv cache, 可选值为 "None", "int8kv", "int4kv", "fp8kv" .. option:: --disable_cudagraph diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 071d9405ab..5d57b137c6 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,14 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `LOADWORKER=18`: 模型加载线程数,提高加载速度 - `--tp 8`: 张量并行度,使用8个GPU -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 - `--port 8088`: 服务端口 1.2 单机 DP + EP 模式 (Data Parallel + Expert Parallel) @@ -55,13 +57,15 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **参数说明:** - `MOE_MODE=EP`: 设置专家并行模式 - `--tp 8`: 张量并行度 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值 -- `--enable_fa3`: 启用 Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 +- `--llm_decode_att_backend fa3`: 启用 Flash Attention 3.0 **可选优化参数:** - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠 @@ -85,7 +89,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +106,8 @@ LightLLM 支持以下几种部署模式: LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +135,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + 
--llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +153,8 @@ LightLLM 支持以下几种部署模式: --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +203,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -219,7 +228,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -287,7 +297,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -306,7 +317,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # 如果需要启用微批次重叠,可以取消注释以下行 diff --git a/docs/CN/source/tutorial/multi_level_cache_deployment.rst b/docs/CN/source/tutorial/multi_level_cache_deployment.rst index 0446b07804..223b92dca3 100644 --- a/docs/CN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/CN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ LightLLM 的多级缓存系统采用分层设计: --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ LightLLM 的多级缓存系统采用分层设计: - ``--model_dir``: 模型文件路径,支持本地路径或 HuggingFace 模型名称 - ``--tp 8``: 张量并行度,使用 8 个 GPU 进行模型推理 - ``--graph_max_batch_size 500``: CUDA Graph 最大批次大小,影响吞吐量和显存占用 -- ``--enable_fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 +- ``--llm_prefill_att_backend fa3``: 启用 Flash Attention 3.0,提升注意力计算速度,也可以换成flashinfer后端性能更佳 - ``--mem_fraction 0.88``: GPU 显存使用比例,建议设置为 0.88及以下 CPU 缓存参数 @@ -130,7 +131,8 @@ CPU 缓存参数 --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/CN/source/tutorial/reasoning_parser.rst b/docs/CN/source/tutorial/reasoning_parser.rst index 547eb05d16..a9a0d09fe4 100644 --- a/docs/CN/source/tutorial/reasoning_parser.rst +++ b/docs/CN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 1644bbab5f..0767ae7e3b 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -183,23 +183,6 @@ Different Parallel Mode Setting Parameters When set to True, --nccl_host must equal config_server_host, --nccl_port must be unique for config_server, do not use the 
same nccl_port for different inference nodes, this will be a serious error -Attention Type Selection Parameters ------------------------------------- - -.. option:: --mode - - Model inference mode, can specify multiple values: - - * ``triton_int8kv``: Use int8 to store kv cache, can increase token capacity, uses triton kernel - * ``ppl_int8kv``: Use int8 to store kv cache, uses ppl fast kernel - * ``ppl_fp16``: Use ppl fast fp16 decode attention kernel - * ``triton_flashdecoding``: Flashdecoding mode for long context, currently supports llama llama2 qwen - * ``triton_gqa_attention``: Fast kernel for models using GQA - * ``triton_gqa_flashdecoding``: Fast flashdecoding kernel for models using GQA - * ``triton_fp8kv``: Use float8 to store kv cache, currently only used for deepseek2 - - Need to read source code to confirm specific modes supported by all models - Scheduling Parameters --------------------- @@ -325,18 +308,6 @@ Performance Optimization Parameters .. option:: --enable_decode_microbatch_overlap The inference backend will use microbatch overlap mode for decoding - -.. option:: --enable_flashinfer_prefill - - The inference backend will use flashinfer's attention kernel for prefill - -.. option:: --enable_flashinfer_decode - - The inference backend will use flashinfer's attention kernel for decoding - -.. option:: --enable_fa3 - - The inference backend will use fa3 attention kernel for prefill and decoding .. option:: --disable_cudagraph diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index 6098411be0..accdbc462b 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -33,12 +33,13 @@ Suitable for deploying DeepSeek-R1 model on a single H200 node. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `LOADWORKER=18`: Model loading thread count, improves loading speed - `--tp 8`: Tensor parallelism, using 8 GPUs -- `--enable_fa3`: Enable Flash Attention 3.0 +- `--llm_prefill_att_backend fa3`: Enable Flash Attention 3.0 - `--port 8088`: Service port 1.2 Single node DP + EP Mode (Data Parallel + Expert Parallel) @@ -55,13 +56,13 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3. --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 **Parameter Description:** - `MOE_MODE=EP`: Set expert parallelism mode - `--tp 8`: Tensor parallelism - `--dp 8`: Data parallelism, usually set to the same value as tp -- `--enable_fa3`: Enable Flash Attention 3.0 **Optional Optimization Parameters:** - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap @@ -85,7 +86,8 @@ Suitable for deployment across multiple H200/H100 nodes. LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -101,7 +103,8 @@ Suitable for deployment across multiple H200/H100 nodes. 
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -129,7 +132,8 @@ Suitable for deploying MoE models across multiple nodes. --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ @@ -146,7 +150,8 @@ Suitable for deploying MoE models across multiple nodes. --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ @@ -195,7 +200,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8019 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip @@ -216,7 +222,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --host $host \ --port 8121 \ --nccl_port 12322 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 @@ -284,7 +291,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --tp 8 \ --dp 8 \ --nccl_port 2732 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 @@ -303,7 +311,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --nccl_port 12322 \ --tp 8 \ --dp 8 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/docs/EN/source/tutorial/multi_level_cache_deployment.rst b/docs/EN/source/tutorial/multi_level_cache_deployment.rst index bb8d943b87..6c99c351f0 100644 --- a/docs/EN/source/tutorial/multi_level_cache_deployment.rst +++ b/docs/EN/source/tutorial/multi_level_cache_deployment.rst @@ -66,7 +66,8 @@ Suitable for most scenarios, significantly increasing cache capacity while maint --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ @@ -81,7 +82,7 @@ Basic Parameters - ``--model_dir``: Model file path, supports local path or HuggingFace model name - ``--tp 8``: Tensor parallelism degree, using 8 GPUs for model inference - ``--graph_max_batch_size 500``: CUDA Graph maximum batch size, affects throughput and memory usage -- ``--enable_fa3``: Enable Flash Attention 3.0 to improve attention computation speed. You can also switch to flashinfer backend for better performance +- ``--llm_prefill_att_backend fa3``: Enable Flash Attention 3.0 to improve attention computation speed. 
You can also switch to flashinfer backend for better performance - ``--mem_fraction 0.88``: GPU memory usage ratio, recommended to set to 0.88 or below CPU Cache Parameters @@ -130,7 +131,8 @@ Suitable for ultra-long text or extremely high-concurrency scenarios, providing --model_dir /path/to/Qwen3-235B-A22B \ --tp 8 \ --graph_max_batch_size 500 \ - --enable_fa3 \ + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 \ --mem_fraction 0.88 \ --enable_cpu_cache \ --cpu_cache_storage_size 400 \ diff --git a/docs/EN/source/tutorial/reasoning_parser.rst b/docs/EN/source/tutorial/reasoning_parser.rst index e76e093d63..56e61e6cd6 100644 --- a/docs/EN/source/tutorial/reasoning_parser.rst +++ b/docs/EN/source/tutorial/reasoning_parser.rst @@ -32,7 +32,8 @@ DeepSeek-R1 --model_dir /path/to/DeepSeek-R1 \ --reasoning_parser deepseek-r1 \ --tp 8 \ - --enable_fa3 + --llm_prefill_att_backend fa3 \ + --llm_decode_att_backend fa3 DeepSeek-V3 ~~~~~~~~~~~ diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py new file mode 100644 index 0000000000..80df545498 --- /dev/null +++ b/lightllm/common/basemodel/attention/__init__.py @@ -0,0 +1,18 @@ +from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from .flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend + +from .create_utils import ( + get_prefill_att_backend_class, + get_decode_att_backend_class, + get_mla_prefill_att_backend_class, + get_mla_decode_att_backend_class, +) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py new file mode 100644 index 0000000000..859d97ca84 --- /dev/null +++ b/lightllm/common/basemodel/attention/base_att.py @@ -0,0 +1,117 @@ +import torch +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Tuple, Union, Dict + +if TYPE_CHECKING: + from lightllm.common.basemodel.basemodel import TpPartBaseModel + from lightllm.common.basemodel.infer_struct import InferStateInfo + + +class BaseAttBackend: + """ + 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 这个是单列模式, 每种backend只有一个实例 + """ + + _instances = {} + + def __new__(cls, *args, **kwargs): + """ + 重写__new__方法实现单例模式 + """ + # 检查是否已经有该类的实例 + if cls not in cls._instances: + # 创建新实例并存储 + instance = super().__new__(cls) + cls._instances[cls] = instance + # 返回已有的实例 + return cls._instances[cls] + + def __init__(self, model: "TpPartBaseModel"): + self.model = model + + def create_att_prefill_state(self) -> "BasePrefillAttState": + raise NotImplementedError("not impl") + + def create_att_decode_state(self) -> "BaseDecodeAttState": + raise NotImplementedError("not impl") + + def _find_layer_index( + self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] + ) -> int: + kv_buffer = att_state.infer_state.mem_manager.kv_buffer + layer_count = len(kv_buffer) + find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)} + key = min(k.data_ptr(), v.data_ptr()) + assert key in find_dict + return find_dict[key] + + +@dataclass 
+class AttControl: + """ + prefill_att 和 decode_att 的入参,用于控制att backend 内部的行为, 选择正确的att 实现。 + """ + + use_alibi: bool = False + tp_alibi: torch.Tensor = None + use_sliding_window: bool = False + sliding_window: Tuple[int, int] = (-1, -1) + use_att_sink: bool = False + sink_weight: torch.Tensor = None + # mla 专用传参项 + mla_prefill: bool = False + mla_prefill_dict: Dict = None + mla_decode: bool = False + mla_decode_dict: Dict = None + + +@dataclass +class BasePrefillAttState(ABC): + + backend: BaseAttBackend = None + infer_state: "InferStateInfo" = None + + @abstractmethod + def init_state(self): + pass + + @abstractmethod + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + raise NotImplementedError("not impl") + + +@dataclass +class BaseDecodeAttState(ABC): + backend: BaseAttBackend = None + infer_state: "InferStateInfo" = None + + @abstractmethod + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "BaseDecodeAttState"): + for attr_name, attr_value in vars(new_state).items(): + if isinstance(attr_value, torch.Tensor): + attr_ = getattr(self, attr_name, None) + if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): + attr_.copy_(attr_value, non_blocking=True) + + @abstractmethod + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + pass diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py new file mode 100644 index 0000000000..39e32ac635 --- /dev/null +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -0,0 +1,80 @@ +from lightllm.utils.envs_utils import get_env_start_args +from .base_att import BaseAttBackend +from .triton.fp import TritonAttBackend +from .triton.int4kv import Int4kvTritonAttBackend +from .triton.int8kv import Int8kvTritonAttBackend +from .triton.mla import MlaTritonAttBackend +from .fa3.fp import Fa3AttBackend +from .fa3.fp8 import Fp8Fa3AttBackend +from .fa3.mla import MlaFa3AttBackend +from .flashinfer.fp8 import Fp8FlashInferAttBackend +from .flashinfer.fp import FlashInferAttBackend +from .flashinfer.mla import MlaFlashInferAttBackend + +data_type_to_backend = { + "None": { + "triton": TritonAttBackend, + "fa3": Fa3AttBackend, + "flashinfer": FlashInferAttBackend, + }, + "int4kv": { + "triton": Int4kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flashinfer": Fp8FlashInferAttBackend, + }, + "int8kv": { + "triton": Int8kvTritonAttBackend, + "fa3": Fp8Fa3AttBackend, + "flashinfer": Fp8FlashInferAttBackend, + }, +} + +mla_data_type_to_backend = { + "None": { + "triton": MlaTritonAttBackend, + "fa3": MlaFa3AttBackend, + "flashinfer": MlaFlashInferAttBackend, + }, +} + + +def get_prefill_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def get_decode_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return data_type_to_backend[llm_dtype][backend_str] + else: + # 根据环境自动选择最好的 + raise NotImplementedError(f"error") + + +def 
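[Editor's note, not part of the patch] The `data_type_to_backend` table introduced in `create_utils.py` above is the heart of the new backend selection: the `--llm_kv_type` value picks a row and `--llm_prefill_att_backend` / `--llm_decode_att_backend` pick a column. A minimal, self-contained sketch of that lookup is shown below; the string class names stand in for the imported backend classes, and the real helpers read these values from the parsed server start args rather than taking them as parameters.

```python
# Illustrative sketch only; mirrors get_decode_att_backend_class from create_utils.py.
DATA_TYPE_TO_BACKEND = {
    "None":   {"triton": "TritonAttBackend",       "fa3": "Fa3AttBackend",    "flashinfer": "FlashInferAttBackend"},
    "int4kv": {"triton": "Int4kvTritonAttBackend", "fa3": "Fp8Fa3AttBackend", "flashinfer": "Fp8FlashInferAttBackend"},
    "int8kv": {"triton": "Int8kvTritonAttBackend", "fa3": "Fp8Fa3AttBackend", "flashinfer": "Fp8FlashInferAttBackend"},
}

def pick_backend(llm_kv_type: str, backend_str: str) -> str:
    # An explicit backend string wins; "None" means "auto-select", which this
    # patch leaves unimplemented (the real code raises NotImplementedError too).
    if backend_str == "None":
        raise NotImplementedError("automatic backend selection is not implemented in this patch")
    return DATA_TYPE_TO_BACKEND[llm_kv_type][backend_str]

assert pick_backend("None", "fa3") == "Fa3AttBackend"              # default kv-cache dtype
assert pick_backend("int8kv", "triton") == "Int8kvTritonAttBackend"  # int8 kv cache + triton kernels
```

An analogous `mla_data_type_to_backend` table covers the MLA (DeepSeek-style) backends.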
get_mla_prefill_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_prefill_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") + + +def get_mla_decode_att_backend_class(index=0) -> BaseAttBackend: + args = get_env_start_args() + llm_dtype = args.llm_kv_type + backend_str = args.llm_decode_att_backend[index] + if backend_str != "None": + return mla_data_type_to_backend[llm_dtype][backend_str] + else: + raise NotImplementedError(f"error") diff --git a/lightllm/models/bloom/triton_kernel/__init__.py b/lightllm/common/basemodel/attention/fa3/__init__.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/__init__.py rename to lightllm/common/basemodel/attention/fa3/__init__.py diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py new file mode 100644 index 0000000000..952bb39d91 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -0,0 +1,243 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class Fa3AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. 
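+        A shared page-table buffer, reused across decode CUDA-graph captures so that
+        each captured graph does not allocate its own table (otherwise memory grows
+        roughly quadratically with the number of captures). As implemented below, two
+        int32 buffers of graph_max_batch_size * graph_max_len_in_batch entries are
+        kept, one per microbatch index.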
+ """ + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "Fa3PrefillAttState": + return Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": + return Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + self.page_table = torch.empty( + (self.infer_state.batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + self.page_table.copy_( + self.infer_state.req_manager.req_to_token_indexs[ + self.infer_state.b_req_idx, : self.infer_state.max_kv_seq_len + ] + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_alibi is False + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fa3AttBackend = self.backend # for typing + + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + sinks=sink_weight, + ) + return o + + +@dataclasses.dataclass +class Fa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: Fa3AttBackend = self.backend + args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = 
gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "Fa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_alibi is False + return self._normal_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight: torch.Tensor = att_control.sink_weight + else: + sink_weight = None + + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + sinks=sink_weight, + ) + return o diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py new file mode 100644 index 0000000000..3feed1ef46 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -0,0 +1,221 @@ +import dataclasses +import torch +from ..base_att import AttControl +from typing import Optional, TYPE_CHECKING +from lightllm.utils.sgl_utils import 
flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops +from typing import Union +from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState + +if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant +else: + scaled_fp8_quant = None + + +class Fp8Fa3AttBackend(Fa3AttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_att_prefill_state(self, infer_state) -> "Fp8Fa3PrefillAttState": + return Fp8Fa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8Fa3DecodeAttState": + return Fp8Fa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8Fa3PrefillAttState(Fa3PrefillAttState): + # 临时共享变量 + mid_token_batch_ids: torch.Tensor = None + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + super().init_state() + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + self.mid_token_batch_ids = torch.repeat_interleave( + torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len + ) + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: Fp8Fa3AttBackend = self.backend # for typing + + q, q_scale = q_per_head_fp8_quant( + q, + self.infer_state.b_seq_len, + self.cu_seqlens_q, + self.mid_token_batch_ids, + ) + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + causal=True, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale, + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], + return_softmax_lse=False, + ) + return o + + +@dataclasses.dataclass +class 
Fp8Fa3DecodeAttState(Fa3DecodeAttState): + k_descale: torch.Tensor = None + v_descale: torch.Tensor = None + + def init_state(self): + super().init_state() + self.backend: Fp8Fa3AttBackend = self.backend + + args_mtp_step = get_env_start_args().mtp_step + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + device = self.infer_state.input_ids.device + batch_size = att_batch_size + mem_manager = self.backend.model.mem_manager + + offline_scales: torch.Tensor = mem_manager.scales + head_num = mem_manager.head_num + + # 为了减少推理计算量,在推理外部初始化k_descale和v_descale + self.k_descale = ( + offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + self.v_descale = ( + offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num) + if offline_scales is not None + else torch.ones( + (mem_manager.layer_num, batch_size, head_num), + dtype=torch.float32, + device=device, + ) + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "Fp8Fa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + + cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn) + + layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self) + + q_head_num = q.shape[1] + q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True) + o = flash_attn_with_kvcache( + q=q.view(-1, q_head_num, k_head_dim), + k_cache=cache_k, + v_cache=cache_v, + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + causal=False, + window_size=(-1, -1), + softcap=0.0, + q_descale=q_scale.view(self.infer_state.batch_size, k_head_num), + k_descale=self.k_descale[layer_index], + v_descale=self.v_descale[layer_index], + return_softmax_lse=False, + ) + return o diff --git a/lightllm/common/basemodel/attention/fa3/mla.py b/lightllm/common/basemodel/attention/fa3/mla.py new file mode 100644 index 0000000000..9a10457b12 --- /dev/null +++ b/lightllm/common/basemodel/attention/fa3/mla.py @@ -0,0 +1,229 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, TYPE_CHECKING, Tuple +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import 
gen_cumsum_pad0_tensor +from lightllm.utils.sgl_utils import flash_attn_varlen_func + + +class MlaFa3AttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.get_page_table_buffer() # init + + def get_page_table_buffer(self): + """ + 用于减少 decode graph 捕获的时候, 造成显存二次方增长的情况. + """ + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + self._shared_page_table_buffer = [ + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + torch.empty(model.graph_max_batch_size * model.graph_max_len_in_batch, dtype=torch.int32).to( + get_current_device_id() + ), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state) -> "MlaFa3PrefillAttState": + return MlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaFa3DecodeAttState": + return MlaFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFa3AttBackend = self.backend # for typing + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + + o_tensor = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=softmax_scale, + causal=True, + return_softmax_lse=False, + ) + return o_tensor + + +@dataclasses.dataclass +class MlaFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + # 在是否开启mtp 的不同模式下,其设置不同的值,可以加速算子的运行。 + decode_max_q_seq_len: int = None + + def init_state(self): + self.backend: MlaFa3AttBackend = self.backend + args_mtp_step = get_env_start_args().mtp_step + + if args_mtp_step > 0: + # 修正 mtp 在 fa3 下的输入。 + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = 
self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + + model = self.backend.model + # 可以使用 cuda graph的时候从 buffer中申请 + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * model.graph_max_len_in_batch + ].reshape(att_batch_size, model.graph_max_len_in_batch) + else: + self.page_table = torch.empty( + (att_batch_size, self.infer_state.max_kv_seq_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, : self.infer_state.max_kv_seq_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + return + + def copy_for_decode_cuda_graph(self, new_state: "MlaFa3DecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + q_nope, q_rope = q + kv = k + qk_rope_head_dim = 64 + kv_lora_rank = kv.shape[-1] - qk_rope_head_dim + k_rope = kv[:, :, -qk_rope_head_dim:].view(-1, 1, 1, qk_rope_head_dim) + kv_nope = kv[:, :, :-qk_rope_head_dim].view(-1, 1, 1, kv_lora_rank) + k_descale, v_descale = None, None + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + o_tensor = flash_attn_with_kvcache( + q=q_rope, + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o_tensor diff --git a/lightllm/models/chatglm2/__init__.py b/lightllm/common/basemodel/attention/flashinfer/__init__.py similarity index 100% rename from lightllm/models/chatglm2/__init__.py rename to lightllm/common/basemodel/attention/flashinfer/__init__.py diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py new file mode 100644 index 0000000000..4c6ec0efc6 --- 
/dev/null +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -0,0 +1,229 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import repack_kv_index + + +class FlashInferAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state) -> "FlashInferPrefillAttState": + return FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "FlashInferDecodeAttState": + return FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class FlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: FlashInferAttBackend = self.backend + + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device=device) + kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + kv_starts[:-1], + self.infer_state.max_kv_seq_len, + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._nomarl_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _nomarl_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + 
self.backend: FlashInferAttBackend = self.backend # for typing + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q, + (k.unsqueeze(1), v.unsqueeze(1)), + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class FlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: FlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + self.kv_last_page_len_buffer = torch.full((self.infer_state.batch_size,), 1, dtype=torch.int32, device=device) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : self.infer_state.batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + self.infer_state.batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_kv_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + assert self.decode_wrapper is None + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + 1, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "FlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + 1, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._normal_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _normal_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + (k.unsqueeze(1), v.unsqueeze(1)), + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/flashinfer/fp8.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py new file mode 100644 index 0000000000..115d6985ac --- /dev/null +++ b/lightllm/common/basemodel/attention/flashinfer/fp8.py @@ -0,0 +1,121 @@ +import dataclasses +import torch +from ..base_att 
import AttControl +from .fp import FlashInferAttBackend, FlashInferPrefillAttState, FlashInferDecodeAttState + + +class Fp8FlashInferAttBackend(FlashInferAttBackend): + def __init__(self, model): + super().__init__(model=model) + self.kv_data_type = torch.float8_e4m3fn + + def create_att_prefill_state(self, infer_state) -> "Fp8FlashInferPrefillAttState": + return Fp8FlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState": + return Fp8FlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_prefill_att( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) + offline_scales = self.offline_scales + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None + self.prefill_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState): + offline_scales: torch.Tensor = None + + def init_state(self): + super().init_state() + self.offline_scales = self.infer_state.mem_manager.scales_list + + def copy_for_decode_cuda_graph(self, new_state): + return super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._fp8_decode_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _fp8_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + + k = k.unsqueeze(1).view(torch.float8_e4m3fn) + v = v.unsqueeze(1).view(torch.float8_e4m3fn) + offline_scales = self.offline_scales + layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) + + k_descale = offline_scales[layer_index][0] if offline_scales is not None else None + v_descale = offline_scales[layer_index][1] if offline_scales is not None else None + self.decode_wrapper.run( + q, + (k, v), + k_scale=k_descale, + v_scale=v_descale, + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py new file mode 100644 index 0000000000..6e52203b4f --- /dev/null +++ 
b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -0,0 +1,233 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import repack_kv_index +from typing import Tuple + + +class MlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + return + + def create_att_prefill_state(self, infer_state) -> "MlaFlashInferPrefillAttState": + return MlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaFlashInferDecodeAttState": + return MlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: MlaFlashInferAttBackend = self.backend + + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.qk_nope_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + return + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att( + q=q, + k=k, + v=v, + alloc_func=alloc_func, + ) + + def _mla_prefill_att( + self, q: torch.Tensor, k: Tuple[torch.Tensor, torch.Tensor], v: torch.Tensor, alloc_func=torch.empty + ) -> torch.Tensor: + self.backend: MlaFlashInferAttBackend = self.backend 
# for typing + k_nope, k_rope = k + o_tensor = alloc_func((q.shape[0], q.shape[1], k_nope.shape[2]), q.dtype, device="cuda") + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class MlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: MlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ + : batch_size * self.backend.max_seq_length + ] + else: + self.kv_indices = torch.empty( + batch_size * self.backend.max_seq_length, + dtype=torch.int32, + device=device, + ) + + repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_seq_len, + self.infer_state.b_kv_start_loc, + self.infer_state.max_kv_seq_len, + self.kv_indices, + ) + assert self.decode_wrapper is None + + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + 1, + False, # causal + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + return + + def copy_for_decode_cuda_graph(self, new_state: "MlaFlashInferDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + 1, + False, # causal + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + qk_rope_head_dim = 64 + q_nope, q_rope = q + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) + assert att_control.mla_decode + + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim], + k[:, :, -qk_rope_head_dim:], + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git 
a/lightllm/models/chatglm2/layer_infer/__init__.py b/lightllm/common/basemodel/attention/triton/__init__.py similarity index 100% rename from lightllm/models/chatglm2/layer_infer/__init__.py rename to lightllm/common/basemodel/attention/triton/__init__.py diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py new file mode 100644 index 0000000000..d29f15ec3b --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -0,0 +1,275 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional + + +class TritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": + return TritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": + return TritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class TritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert att_control.use_sliding_window is False and att_control.use_att_sink is False + if att_control.use_alibi: + assert att_control.tp_alibi is not None + return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + else: + return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + + def _alibi_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + out = alloc_func(q.shape, q.dtype) + + from ...triton_kernel.alibi_att.context_flashattention_nopad import context_attention_fwd + + context_attention_fwd( + q, + k, + v, + out, + self.infer_state.b_req_idx, + att_control.tp_alibi, + self.infer_state.b_q_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + self.infer_state.req_manager.req_to_token_indexs, + ) + return out + + def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + + out = alloc_func(q.shape, q.dtype) + context_attention_fwd( + q, + k, + v, + out, + self.infer_state.b_req_idx, + self.infer_state.b_q_start_loc, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + self.infer_state.req_manager.req_to_token_indexs, + ) + return out + + +@dataclasses.dataclass +class TritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "TritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert att_control.use_sliding_window is False and att_control.use_att_sink is False + if att_control.use_alibi: + assert att_control.tp_alibi is not None + return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + else: + q_head_num = q.shape[1] + k_head_num = k.shape[1] + if q_head_num == k_head_num: + return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, 
alloc_func=alloc_func) + elif q_head_num > k_head_num: + return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) + else: + raise NotImplementedError("error") + + def _alibi_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + from ...triton_kernel.alibi_att.token_flashattention_nopad import token_attention_fwd + + out = alloc_func(q.shape, q.dtype) + token_attention_fwd( + q, + k, + v, + out, + att_control.tp_alibi, + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + self.infer_state.b_kv_start_loc, + self.infer_state.b_seq_len, + self.infer_state.max_kv_seq_len, + self.infer_state.total_token_num, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + from ...triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + out = alloc_func(q.shape, q.dtype) + + token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_gqa_flash_decoding_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, + ) + + out = alloc_func(q.shape, q.dtype) + + gqa_token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_v=v, + out=out, + alloc_tensor_func=alloc_func, + ) + + return out + + def _normal_decode_gqa_flash_decoding_att_vsm( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_vsm import ( + gqa_token_decode_attention_flash_decoding_vsm, + ) + + out = alloc_func(q.shape, q.dtype) + + gqa_token_decode_attention_flash_decoding_vsm( + q=q, + k=k, + v=v, + infer_state=self.infer_state, + out=out, + alloc_tensor_func=alloc_func, + ) + return out + + def _normal_decode_gqa_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_weight, + alloc_func=torch.empty, + ): + # TODO USE , 在特定场景下比 _normal_decode_gqa_flash_decoding_att 省显存 + from ...triton_kernel.att.decode_att.gqa.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + + out = alloc_func(q.shape, q.dtype) + + gqa_decode_attention_fwd( + q=q, + k=k, + v=v, + out=out, + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_seq_len=self.infer_state.b_seq_len, + ) + return out + + def _normal_decode_stage3_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + alloc_func=torch.empty, + ): + total_token_num = self.infer_state.total_token_num + batch_size = self.infer_state.batch_size + q_head_num = q.shape[1] + head_dim = q.shape[2] + + calcu_shape1 = (batch_size, q_head_num, head_dim) + att_m_tensor = alloc_func((q_head_num, total_token_num), torch.float32) + + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_nopad_att1 import token_att_fwd + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + Req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + 
B_req_idx=self.infer_state.b_req_idx, + B_Start_Loc=self.infer_state.b_kv_start_loc, + B_Seqlen=self.infer_state.b_seq_len, + max_len_in_batch=self.infer_state.max_kv_seq_len, + ) + + o_tensor = alloc_func(q.shape, q.dtype) + from ...triton_kernel.att.decode_att.mha.stage3_decode_att.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o_tensor.view(calcu_shape1), + req_to_tokens=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + b_start_loc=self.infer_state.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/int4kv.py b/lightllm/common/basemodel/attention/triton/int4kv.py new file mode 100644 index 0000000000..25199dc470 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/int4kv.py @@ -0,0 +1,170 @@ +import dataclasses +import torch +from lightllm.utils.envs_utils import get_env_start_args +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple + + +class Int4kvTritonAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int4kvTritonPrefillAttState": + return Int4kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int4kvTritonDecodeAttState": + return Int4kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int4kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + + def init_state(self): + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + self.backend: Int4kvTritonAttBackend = self.backend # for typing + + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + head_dim = k.shape[2] * 2 # 2个4bit存储为一个int8, 所以维度需要翻倍,才是解量化后的精度 + k_dequant = alloc_func((total_token_num, k.shape[1], head_dim), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], head_dim), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ...triton_kernel.kv_copy.ppl_int4kv_copy_kv import 
dequantize_int4kv + + dequantize_int4kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + + context_attention_fwd_contiguous_kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_q_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) + return o_tensor + + +@dataclasses.dataclass +class Int4kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int4kvTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k, k_scale = k + v, v_scale = v + + return self.ppl_int4kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + + def ppl_int4kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ...triton_kernel.att.decode_att.int4kv.ppl_int4kv_flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py new file mode 100644 index 0000000000..6a795c4376 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -0,0 +1,196 @@ +import dataclasses +import torch +from lightllm.utils.envs_utils import get_env_start_args +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Optional, Tuple +from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel + + +class Int8kvTritonAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model) + self.quant_group_size: int = get_env_start_args().llm_kv_quant_group_size + + def create_att_prefill_state(self, infer_state) -> "Int8kvTritonPrefillAttState": + return Int8kvTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "Int8kvTritonDecodeAttState": + return Int8kvTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class Int8kvTritonPrefillAttState(BasePrefillAttState): + + # 用于反量化的时候使用,可以减少反量化占用的显存数量。按需使用。 + b_kv_start_loc: torch.Tensor = None + + def init_state(self): + self.b_kv_start_loc = ( + torch.cumsum(self.infer_state.b_seq_len, dim=0, dtype=self.infer_state.b_seq_len.dtype) + - self.infer_state.b_seq_len + ) + + def prefill_att( 
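+        # Int8 kv-cache path: k and v arrive as (quantized int8 tensor, per-group scale)
+        # tuples; the body below dequantizes them with the backend's quant_group_size
+        # before running the Triton context-attention kernel.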
+ self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + + self.backend: Int8kvTritonAttBackend = self.backend # for typing + + k, k_scale = k + v, v_scale = v + o = self._groupsize_quant_prefill_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + return o + + def _groupsize_quant_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + # o_tensor = alloc_func(q.shape, q.dtype, device=q.device) + # batch_size = self.infer_state.b_seq_len.shape[0] + + assert k.untyped_storage().data_ptr() == v.untyped_storage().data_ptr() + assert k_scale.untyped_storage().data_ptr() == v_scale.untyped_storage().data_ptr() + + total_token_num = self.infer_state.total_token_num + k_dequant = alloc_func((total_token_num, k.shape[1], k.shape[2]), dtype=q.dtype, device=q.device) + v_dequant = alloc_func((total_token_num, v.shape[1], v.shape[2]), dtype=q.dtype, device=q.device) + o_tensor = alloc_func(q.shape, dtype=q.dtype, device=q.device) + + max_kv_seq_len = self.infer_state.max_kv_seq_len + + from ...triton_kernel.kv_copy.ppl_int8kv_copy_kv import dequantize_int8kv + + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_seq_len=self.infer_state.b_seq_len, + b_req_idx=self.infer_state.b_req_idx, + b_kv_start_loc=self.b_kv_start_loc, + k_out=k_dequant, + v_out=v_dequant, + max_len_in_batch=max_kv_seq_len, + quant_group_size=self.backend.quant_group_size, + ) + + from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd_contiguous_kv + + context_attention_fwd_contiguous_kv( + q=q, + k=k_dequant, + v=v_dequant, + o=o_tensor, + b_start_loc=self.infer_state.b_q_start_loc, + b_kv_start_loc=self.b_kv_start_loc, + b_seq_len=self.infer_state.b_seq_len, + max_q_input_len=self.infer_state.max_q_seq_len, + b_prompt_cache_len=self.infer_state.b_ready_cache_len, + ) + return o_tensor + + +@dataclasses.dataclass +class Int8kvTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "Int8kvTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: Tuple[torch.Tensor, torch.Tensor], + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k, k_scale = k + v, v_scale = v + if enable_diverse_mode_gqa_decode_fast_kernel(): + return self.diverse_decode_att(q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, alloc_func=alloc_func) + else: + return self.ppl_mha_int8kv_decode_att( + q=q, + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + alloc_func=alloc_func, + ) + + def diverse_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + + from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + 
token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) + + def ppl_mha_int8kv_decode_att( + self, + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + alloc_func=torch.empty, + ) -> torch.Tensor: + from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( + token_decode_attention_flash_decoding, + ) + + return token_decode_attention_flash_decoding( + q=q, + infer_state=self.infer_state, + cache_k=k, + cache_k_scale=k_scale, + cache_v=v, + cache_v_scale=v_scale, + alloc_tensor_func=alloc_func, + ) diff --git a/lightllm/common/basemodel/attention/triton/mla.py b/lightllm/common/basemodel/attention/triton/mla.py new file mode 100644 index 0000000000..8288193ad7 --- /dev/null +++ b/lightllm/common/basemodel/attention/triton/mla.py @@ -0,0 +1,125 @@ +import dataclasses +import torch +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from typing import Tuple + + +class MlaTritonAttBackend(BaseAttBackend): + def create_att_prefill_state(self, infer_state) -> "MlaTritonPrefillAttState": + return MlaTritonPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state) -> "MlaTritonDecodeAttState": + return MlaTritonDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class MlaTritonPrefillAttState(BasePrefillAttState): + def init_state(self): + pass + + def prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ) -> torch.Tensor: + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, + q: torch.Tensor, + k: Tuple[torch.Tensor, torch.Tensor], + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + from ...triton_kernel.mla_att.prefill_att import context_attention_fwd_with_v + + qk_rope_head_dim = 64 + q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:] + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q.device) + k_nope, k_rope = k + assert att_control.mla_prefill + softmax_scale = att_control.mla_prefill_dict["softmax_scale"] + context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o_tensor, + self.infer_state.b_q_start_loc, + self.infer_state.b1_cu_kv_seq_len, + self.infer_state.b_seq_len, + self.infer_state.b_ready_cache_len, + self.infer_state.max_q_seq_len, + softmax_scale, + ) + return o_tensor + + +@dataclasses.dataclass +class MlaTritonDecodeAttState(BaseDecodeAttState): + def init_state(self): + pass + + def copy_for_decode_cuda_graph(self, new_state: "MlaTritonDecodeAttState"): + super().copy_for_decode_cuda_graph(new_state) + + def decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): + assert ( + att_control.use_sliding_window is False + and att_control.use_att_sink is False + and att_control.use_alibi is False + ) + assert v is None + + return self._mla_decode_att( + q=q, + k=k, + v=v, + 
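+ # v is asserted to be None above: MLA decode operates on a single compressed KV
+ # tensor, which _mla_decode_att splits into its "nope" part and the trailing
+ # qk_rope_head_dim (= 64) rope part before invoking the flash-decoding kernel.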
att_control=att_control, + alloc_func=alloc_func, + ) + + def _mla_decode_att( + self, + q: Tuple[torch.Tensor, torch.Tensor], + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl, + alloc_func=torch.empty, + ): + assert att_control.mla_decode + softmax_scale = att_control.mla_decode_dict["softmax_scale"] + + from ...triton_kernel.mla_att.decode_att import gqa_token_decode_attention_flash_decoding + + qk_rope_head_dim = 64 + q_nope, q_rope = q + kv = k + + out = gqa_token_decode_attention_flash_decoding( + q_nope=q_nope, + q_rope=q_rope, + kv_nope=kv[:, :, :-qk_rope_head_dim], + kv_rope=kv[:, :, -qk_rope_head_dim:], + infer_state=self.infer_state, + softmax_scale=softmax_scale, + alloc_tensor_func=alloc_func, + ) + return out diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 011f998fc0..26d51af3db 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,6 +32,8 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class +from .attention import BaseAttBackend logger = init_logger(__name__) @@ -58,7 +60,6 @@ def __init__(self, kvargs): self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") - self.mode = kvargs.get("mode", []) self.weight_dict = kvargs.get("weight_dict", None) self.finetune_config = kvargs.get("finetune_config", None) self.max_req_num = kvargs.get("max_req_num", 1000) @@ -116,9 +117,18 @@ def __init__(self, kvargs): self._init_infer_layer() self._init_some_value() self._init_custom() - self._init_inferstate_cls() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + + self._init_att_backend() + self._init_att_backend1() + + logger.info(f"use prefill att backend: {self.prefill_att_backend.__class__.__name__}") + logger.info(f"use decode att backend: {self.decode_att_backend.__class__.__name__}") + if self.prefill_att_backend1 is not None: + logger.info(f"use prefill att backend1: {self.prefill_att_backend1.__class__.__name__}") + logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}") + self._autotune_warmup() self._init_padded_req() self._init_cudagraph() @@ -144,9 +154,6 @@ def _init_config(self): self.config["vocab_size"] = self.finetune_config.vocab_size return - def _init_inferstate_cls(self): - pass - @final def _verify_must(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 @@ -162,15 +169,12 @@ def _init_quant(self): logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) @@ -220,10 +224,10 @@ def _init_req_manager(self): return def _init_infer_layer(self, start_layer_index=0): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(start_layer_index, start_layer_index + self.config["n_layer"]) ] return @@ -238,6 +242,17 @@ def _init_some_value(self): self.vocab_size = self.config["vocab_size"] return + def _init_att_backend(self): + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0)(model=self) + return + + def _init_att_backend1(self): + # self.prefill_att_backend1 是给后续有模型支持不同层用不同的att模块时,保留的扩展。 + self.prefill_att_backend1: BaseAttBackend = None + self.decode_att_backend1: BaseAttBackend = None + return + def _init_cudagraph(self): self.graph = ( None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch) @@ -281,7 +296,6 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) infer_state.return_all_prompt_logics = self.return_all_prompt_logics infer_state.batch_size = model_input.batch_size infer_state.total_token_num = model_input.total_token_num - infer_state.max_len_in_batch = model_input.max_len_in_batch infer_state.max_q_seq_len = model_input.max_q_seq_len infer_state.max_kv_seq_len = model_input.max_kv_seq_len infer_state.max_cache_len = model_input.max_cache_len @@ -311,6 +325,19 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0) # 特殊模型,特殊模式的特定变量初始化操作。 infer_state.mtp_draft_input_hiddens = model_input.mtp_draft_input_hiddens + if infer_state.is_prefill: + infer_state.prefill_att_state = self.prefill_att_backend.create_att_prefill_state(infer_state=infer_state) + if self.prefill_att_backend1 is not None: + infer_state.prefill_att_state1 = self.prefill_att_backend1.create_att_prefill_state( + infer_state=infer_state + ) + else: + infer_state.decode_att_state = self.decode_att_backend.create_att_decode_state(infer_state=infer_state) + if self.decode_att_backend1 is not None: + infer_state.decode_att_state1 = self.decode_att_backend1.create_att_decode_state( + infer_state=infer_state + ) + return infer_state def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_size: int): @@ -323,6 +350,7 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s new_model_input = copy.copy(model_input) new_model_input.batch_size = new_batch_size new_model_input.total_token_num += 
padded_batch_size * 2 + new_model_input.max_kv_seq_len = max(2, model_input.max_kv_seq_len) new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_batch_size), mode="constant", value=1) new_model_input.b_req_idx = F.pad( new_model_input.b_req_idx, (0, padded_batch_size), mode="constant", value=self.req_manager.HOLD_REQUEST_ID @@ -366,7 +394,6 @@ def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle new_model_input = copy.copy(model_input) new_model_input.batch_size = model_input.batch_size + 1 new_model_input.total_token_num += padded_token_num - new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch) new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len) new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len) new_model_input.max_cache_len = max(0, model_input.max_cache_len) @@ -464,6 +491,7 @@ def _prefill( prefill_mem_indexes_ready_event.record() infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._context_forward(infer_state) if is_padded_model_input: model_output = self._create_unpad_prefill_model_output( @@ -484,7 +512,7 @@ def _decode( model_input.b_mtp_index, ) - if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): + if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_kv_seq_len): find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) @@ -495,6 +523,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True @@ -514,6 +543,7 @@ def _decode( infer_state.mem_index, ) infer_state.init_some_extra_state(self) + infer_state.init_att_state() model_output = self._token_forward(infer_state) return model_output @@ -622,6 +652,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state0.max_q_seq_len, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() infer_state1 = self._create_inferstate(model_input1, 1) init_req_to_token_indexes( @@ -634,6 +665,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod max_q_seq_len=infer_state1.max_q_seq_len, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() @@ -672,7 +704,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode assert model_input1.mem_indexes.is_cuda origin_batch_size = model_input0.batch_size - max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) + max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) @@ -688,6 +720,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(padded_model_input1, 1) copy_kv_index_to_req( 
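+ # copy_kv_index_to_req records the newly allocated KV cache slots (mem_index) in
+ # req_to_token_indexs for each request of this microbatch before the decode forward
+ # runs; init_att_state() is then called once init_some_extra_state() has prepared
+ # the remaining per-batch state.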
self.req_manager.req_to_token_indexs, @@ -696,6 +730,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True @@ -724,6 +759,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state0.mem_index, ) infer_state0.init_some_extra_state(self) + infer_state0.init_att_state() + infer_state1 = self._create_inferstate(model_input1, 1) copy_kv_index_to_req( self.req_manager.req_to_token_indexs, @@ -732,6 +769,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.mem_index, ) infer_state1.init_some_extra_state(self) + infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) return model_output0, model_output1 @@ -818,7 +856,6 @@ def _check_max_len_infer(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=self.batch_max_tokens, max_q_seq_len=self.batch_max_tokens, max_kv_seq_len=self.batch_max_tokens, max_cache_len=0, @@ -895,7 +932,6 @@ def _autotune_warmup(self): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=input_len, max_q_seq_len=input_len, max_kv_seq_len=input_len, max_cache_len=0, @@ -958,7 +994,6 @@ def _init_padded_req(self): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=prefill_input_len, max_q_seq_len=prefill_input_len, max_kv_seq_len=prefill_input_len, max_cache_len=0, diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 138f084270..758c0b5194 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -11,10 +11,7 @@ class ModelInput: # 通用变量 batch_size: int total_token_num: int - max_len_in_batch: int - # 在 decode 阶段, 常规模式下, max_q_seq_len 必定是 1, - # 在 mtp 模式下,max_q_seq_len 统计的是一个请求考虑了 mtp 步数的 - # 最大长度,实际值是 max([(1 + req.mtp_step) for req in reqs]) + # 在 decode 阶段, max_q_seq_len 必定是 1, max_q_seq_len: int max_kv_seq_len: int max_cache_len: int = None diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 9eeab7270c..dd29c9a833 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -67,7 +67,7 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size # warmup # 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上 @@ -77,10 +77,16 @@ def _capture_decode(self, decode_func, infer_state: InferStateInfo): # 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性, # 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象 # 中的 tensor。 + for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state)) torch.cuda.synchronize() + for param_name in set(vars(infer_state).keys()): + if param_name not in pure_para_set: + delattr(infer_state, param_name) with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): @@ -100,15 
+106,25 @@ def _capture_decode_overlap( graph_obj = torch.cuda.CUDAGraph() input_ids = infer_state.input_ids batch_size = input_ids.shape[0] - infer_state.max_len_in_batch = self.graph_max_len_in_batch + infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size - infer_state1.max_len_in_batch = self.graph_max_len_in_batch + infer_state1.max_kv_seq_len = self.graph_max_len_in_batch infer_state1.total_token_num = self.graph_max_len_in_batch * batch_size # warmup for _ in range(1): + # 记录原始存在的变量 + pure_para_set = set(vars(infer_state).keys()) + pure_para_set1 = set(vars(infer_state1).keys()) torch.cuda.synchronize() decode_func(copy.copy(infer_state), copy.copy(infer_state1)) torch.cuda.synchronize() + for para_name in set(vars(infer_state).keys()): + if para_name not in pure_para_set: + delattr(infer_state, para_name) + for para_name in set(vars(infer_state1).keys()): + if para_name not in pure_para_set1: + delattr(infer_state1, para_name) + with lightllm_capture_graph(dist_group1): with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): @@ -196,8 +212,7 @@ def warmup(self, model): model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, mem_indexes=mem_indexes, @@ -256,8 +271,7 @@ def warmup_overlap(self, model): is_prefill=False, batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, - max_q_seq_len=self.mtp_step + 1, + max_q_seq_len=1, max_kv_seq_len=max_len_in_batch, input_ids=input_ids, b_mtp_index=b_mtp_index, diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 8e7174bb39..75856b1086 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -11,6 +11,7 @@ from .batch_objs import ModelInput from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_dp_rank +from .attention import BasePrefillAttState, BaseDecodeAttState class InferStateInfo: @@ -19,21 +20,24 @@ class InferStateInfo: """ def __init__(self): + # prefill 和 decode 使用的 att 状态对象 + self.prefill_att_state: BasePrefillAttState = None + self.decode_att_state: BaseDecodeAttState = None + + # 保留的扩展, 支持线性att与标准att混合使用时使用 + self.prefill_att_state1: BasePrefillAttState = None + self.decode_att_state1: BaseDecodeAttState = None + self.input_ids: torch.Tensor = None self.batch_size: int = None self.total_token_num: int = None self.b_req_idx: torch.Tensor = None - self.b_start_loc: torch.Tensor = None self.b_ready_cache_len: torch.Tensor = None # only for prefill prompt cache used. self.b_shared_seq_len: torch.Tensor = None # only for diverse mode used in decode phase. self.b_mark_shared_group: torch.Tensor = None # only for diverse mode used in decode phase. 
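+ # The prefill_att_state / decode_att_state objects declared above are created per
+ # batch in BaseModel._create_inferstate() and initialized via init_att_state() right
+ # after init_some_extra_state(); decode states additionally implement
+ # copy_for_decode_cuda_graph() so a captured CUDA graph can be replayed with new inputs.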
self.b_seq_len: torch.Tensor = None - # max_len_in_batch prefill 和 decode 阶段含义不同 - # prefill 阶段指每个req 输入token的长度(不包括已经cache的部分)最大值 - # decode 阶段指的是每个req的总长 最大值 - self.max_len_in_batch: int = None # max_cache_len 用于 prefill 阶段标识请求中最大 cache的kv 的长度 self.max_cache_len: int = None # prefix_total_token_num 用于 prefill 阶段标识当前请求中所有已经ready的kv的长度 @@ -68,6 +72,11 @@ def __init__(self): self.max_q_seq_len: int = None self.max_kv_seq_len: int = None + # prefill 用 + self.b_q_start_loc: torch.Tensor = None + # decode 用 + self.b_kv_start_loc: torch.Tensor = None + # 一些特殊模型,特殊模式使用的输入变量,本身这些变量不适合放在 # inferstate的基类中,但是为了代码的简洁和方便,都放在基类中 # 进行管理。注意这些成员变量只会在特定的模型和模式下才会生效。 @@ -90,7 +99,6 @@ def __init__(self): self.dp_input_split_sizes: List[List[int]] = None def init_some_extra_state(self, model): - if self.is_prefill: ( self.b_q_seq_len, @@ -103,7 +111,7 @@ def init_some_extra_state(self, model): b_ready_cache_len=self.b_ready_cache_len, b_seq_len=self.b_seq_len, ) - self.b_start_loc = self.b1_cu_q_seq_len[0:-1] + self.b_q_start_loc = self.b1_cu_q_seq_len[0:-1] else: ( self.b_q_seq_len, @@ -112,9 +120,17 @@ def init_some_extra_state(self, model): self.b1_cu_kv_seq_len, self.position_ids, ) = gen_decode_params(self.b_seq_len) - # TODO: check the correctness - self.max_kv_seq_len = self.max_len_in_batch - self.b_start_loc = self.b1_cu_kv_seq_len[0:-1] + self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1] + + def init_att_state(self): + if self.is_prefill: + self.prefill_att_state.init_state() + if self.prefill_att_state1 is not None: + self.prefill_att_state1.init_state() + else: + self.decode_att_state.init_state() + if self.decode_att_state1 is not None: + self.decode_att_state1.init_state() def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): for attr_name, attr_value in vars(new_infer_state).items(): @@ -122,6 +138,10 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"): attr_ = getattr(self, attr_name, None) if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr(): attr_.copy_(attr_value, non_blocking=True) + + self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state) + if self.decode_att_state1 is not None: + self.decode_att_state1.copy_for_decode_cuda_graph(new_infer_state.decode_att_state1) return def prefill_dp_balance(self, input_ids: torch.Tensor): diff --git a/lightllm/common/basemodel/layer_infer/post_layer_infer.py b/lightllm/common/basemodel/layer_infer/post_layer_infer.py index d254eb510a..c7bae26ead 100644 --- a/lightllm/common/basemodel/layer_infer/post_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/post_layer_infer.py @@ -4,8 +4,7 @@ class PostLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py index 3626346f20..e83fe89490 100644 --- a/lightllm/common/basemodel/layer_infer/pre_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/pre_layer_infer.py @@ -4,8 +4,7 @@ class PreLayerInfer(BaseLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): super().__init__() self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py index 
fa7e96a694..1b7813fca3 100644 --- a/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/post_layer_infer_template.py @@ -6,8 +6,8 @@ class PostLayerInferTpl(PostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 self.vocab_size_ = network_config["vocab_size"] self.embed_dim_ = network_config["n_embed"] diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index e7a0840794..04f8cda16b 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -5,8 +5,8 @@ class PreLayerInferTpl(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-5 return diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py deleted file mode 100755 index 27f71a17ec..0000000000 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py +++ /dev/null @@ -1,136 +0,0 @@ -from functools import partial -from typing import Tuple - -import torch -import torch.distributed as dist - -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl -from lightllm.utils.infer_utils import mark_cost_time - -from ...infer_struct import InferStateInfo -from ..transformer_layer_infer import TransformerLayerInfer -from lightllm.distributed.communication_op import all_reduce - - -class TransformerLayerCohereInferTpl(TransformerLayerInferTpl): - """ """ - - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - - self.use_qk_norm_ = self.network_config_.get("use_qk_norm", False) - return - - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _bind_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - self._att_norm = partial(TransformerLayerCohereInferTpl._q_norm, self) - self._q_norm = partial(TransformerLayerCohereInferTpl._k_norm, self) - self._k_norm = partial(TransformerLayerCohereInferTpl._att_norm, self) - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - raise Exception("need to impl") - - def _bind_rotary_emb_fwd(self): - raise Exception("need to impl") - - def _get_qkv( - self, input, infer_state: InferStateInfo, layer_weight - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)) - cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view( - -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ - ) - - if self.use_qk_norm_: - q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = 
self._q_norm(q, infer_state, layer_weight) - cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight) - self._rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: - raise Exception("need to impl") - - def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight): - q, cache_kv = self._get_qkv(input_embding, infer_state, layer_weight) - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(o, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._attn_out = o - return - - def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): - ffn_out = self._ffn(input_embdings, infer_state, layer_weight) - if self.tp_world_size_ > 1: - all_reduce(ffn_out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - infer_state._ffn_out = ffn_out - return - - def _cohere_residual(self, input_embdings, infer_state: InferStateInfo): - # emb_addr = input_embdings.data_ptr() - # attn_out_addr = infer_state._attn_out.data_ptr() - # ffn_addr = infer_state._ffn_out.data_ptr() - # assert emb_addr != attn_out_addr - # assert emb_addr != ffn_addr - # assert attn_out_addr != ffn_addr - input_embdings.add_( - infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_) - ) - - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._context_attention(input1, infer_state, layer_weight=layer_weight) - self._context_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings - - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - self._token_attention(input1, infer_state, layer_weight=layer_weight) - 
self._token_ffn(input1, infer_state, layer_weight) - self._cohere_residual(input_embdings, infer_state) - return input_embdings diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 436ca77d84..9153349c5d 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -3,8 +3,6 @@ import torch.distributed as dist from ..transformer_layer_infer import TransformerLayerInfer from ...infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.distributed import all_reduce from typing import Tuple from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -13,8 +11,8 @@ class TransformerLayerInferTpl(TransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) # need to set by subclass self.eps_ = 1e-5 self.tp_q_head_num_ = -1 @@ -39,11 +37,11 @@ def _tpsp_get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tup def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight): mem_manager = infer_state.mem_manager - self._copy_kv_to_mem_cache(cache_kv, infer_state.mem_index, mem_manager) - return - - def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + mem_manager.copy_kv_to_mem_manager( + layer_index=self.layer_num_, + mem_index=infer_state.mem_index, + kv=cache_kv, + ) return def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py index 7350531bbb..53daffcddf 100644 --- a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py @@ -4,9 +4,8 @@ class TransformerLayerInfer(BaseLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode): + def __init__(self, layer_num, network_config): super().__init__() self.layer_num_ = layer_num self.network_config_ = network_config - self.mode = mode return diff --git a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py index 19eb67017d..8c81fd5bcc 100644 --- a/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/pre_and_post_layer_weight.py @@ -3,11 +3,10 @@ class PreAndPostLayerWeight(BaseLayerWeight): - def __init__(self, data_type, network_config, mode): + def __init__(self, data_type, network_config): super().__init__() self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.init_static_params() return diff --git a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py index 97bc762370..4bc58c76f6 100644 --- a/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/common/basemodel/layer_weights/transformer_layer_weight.py @@ -9,12 +9,11 @@ class TransformerLayerWeight(BaseLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg): + def __init__(self, layer_num, data_type, network_config, quant_cfg): super().__init__() self.layer_num_ = layer_num self.data_type_ = data_type self.network_config_ = network_config - self.mode = mode self.quant_cfg = quant_cfg self._parse_config() self._init_weight_names() diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3d77a3ae4c..a8b2616418 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -168,7 +168,6 @@ def warmup(self, model): model_input = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, @@ -229,7 +228,6 @@ def warmup_overlap(self, model): micro_batch = ModelInput( batch_size=1, total_token_num=total_token_num, - max_len_in_batch=total_token_num, max_q_seq_len=total_token_num, max_kv_seq_len=total_token_num, max_cache_len=0, diff --git a/lightllm/models/chatglm2/layer_weights/__init__.py b/lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py similarity index 100% rename from lightllm/models/chatglm2/layer_weights/__init__.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/__init__.py diff --git a/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/context_flashattention_nopad.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_att1.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_att1.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_reduceV.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_reduceV.py diff --git a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py similarity index 77% rename from lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py index 25af80fabf..adf97735f6 100644 --- a/lightllm/models/bloom/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/token_attention_nopad_softmax.py @@ -6,11 +6,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -19,16 +23,22 
@@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return @@ -44,10 +54,14 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) @@ -59,7 +73,7 @@ def test1(): import torch B, N_CTX, H, D = 4, 1025, 12, 128 - + del D dtype = torch.float16 Logics = torch.empty((H, B * N_CTX), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) @@ -85,6 +99,7 @@ def test2(): import torch B, N_CTX, H, D = 3, 1025, 12, 128 + del D dtype = torch.float16 @@ -107,7 +122,7 @@ def test2(): start = 0 for i in range(B): end = start + b_seq_len[i] - torch_o = Logics[:, start: end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) + torch_o = Logics[:, start:end].reshape(H * 1, -1).softmax(-1).reshape(H, 1 * b_seq_len[i]) start = end torch_out.append(torch_o) torch_out = torch.cat(torch_out, dim=-1) diff --git a/lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py similarity index 100% rename from lightllm/models/bloom/triton_kernel/token_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py diff --git a/lightllm/models/chatglm2/triton_kernel/__init__.py b/lightllm/common/basemodel/triton_kernel/att/__init__.py similarity index 100% rename from lightllm/models/chatglm2/triton_kernel/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/__init__.py diff --git a/lightllm/models/cohere/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py similarity index 100% rename from lightllm/models/cohere/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/__init__.py diff --git a/lightllm/models/cohere/layer_infer/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py similarity index 100% rename from lightllm/models/cohere/layer_infer/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/__init__.py diff --git a/lightllm/models/cohere/layer_weights/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py similarity index 100% rename from 
lightllm/models/cohere/layer_weights/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/__init__.py diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py similarity index 65% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index 67be7c968b..26ec3ebd71 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -2,11 +2,12 @@ def gqa_token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty + q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty ): BLOCK_SEQ = 128 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) from .gqa_flash_decoding_stage1 import flash_decode_stage1 @@ -15,10 +16,10 @@ def gqa_token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -28,7 +29,7 @@ def gqa_token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py similarity index 96% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 320c2cf798..2814ff44bc 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -123,8 +123,18 @@ def _fwd_kernel_flash_decode_stage1( @torch.no_grad() def flash_decode_stage1( - q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq + q, + k: torch.Tensor, + v: torch.Tensor, + Req_to_tokens, + B_req_idx, + B_Seqlen, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, ): + assert k.stride() == v.stride() BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py similarity index 65% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py rename 
to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 81227f967b..101e99dde5 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -6,14 +6,22 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): + BLOCK_DMODEL: tl.constexpr, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -32,33 +40,43 @@ def _fwd_kernel_flash_decode_stage2( tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) - + old_scale = tl.exp(max_logic - new_max_logic) acc *= old_scale exp_logic = tl.exp(tlogic - new_max_logic) acc += exp_logic * tv sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) - + _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), + B_Seqlen, + mid_out, + mid_out_logexpsum, + out, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) - return \ No newline at end of file + return diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py similarity index 99% rename from lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 850d4185c3..6a9bb79c7d 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -421,7 +421,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = infer_state.max_len_in_batch + avg_seq_len_in_batch = infer_state.max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git 
a/lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py similarity index 100% rename from lightllm/models/llama/triton_kernel/gqa_decode_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/gqa_decode_flashattention_nopad.py diff --git a/lightllm/models/cohere/triton_kernels/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py similarity index 100% rename from lightllm/models/cohere/triton_kernels/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py new file mode 100644 index 0000000000..212825a962 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/int4kv_flash_decoding_stage1.py @@ -0,0 +1,200 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def int4_to_float(k_int8, k_scale, offs_d): + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 + k_int4 = tl.where( + offs_d[None, :] % 2 == 0, + k_low, + k_high, + ) + k = k_int4.to(k_scale.dtype) * k_scale + return k + + +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + K_scale, + V, + V_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + quant_group_size, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + k_loc = k_loc.to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] // 2 + off_k_scale = off_k // (quant_group_size // 2) + k_int8 = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + k_scale = 
tl.load(K_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + k = int4_to_float(k_int8, k_scale, offs_d) + + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where((offs_n_new < cur_batch_end_index), att_value, float("-inf")) + v_int8 = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0) + v_scale = tl.load(V_scale + off_k_scale, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) + v = int4_to_float(v_int8, v_scale, offs_d) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + tl.store(Mid_O + off_mid_o, acc / sum_exp) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +@torch.no_grad() +def int4kv_flash_decode_stage1( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] * 2 + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, head_num = B_req_idx.shape[0], q.shape[1] + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + quant_group_size = Lk // k_scale.shape[-1] + assert triton.next_power_of_2(quant_group_size) == quant_group_size + assert k.stride() == v.stride() + # TODO 优化为gqa使用tensor core的实现,速度更快。 + _fwd_kernel_flash_decode_stage1[grid]( + q, + k, + k_scale, + v, + v_scale, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + quant_group_size, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=2, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py new file mode 100644 index 0000000000..a5a054b93a --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -0,0 +1,50 @@ +import torch + + +def token_decode_attention_flash_decoding( + q, + infer_state, + cache_k, + cache_k_scale, + cache_v, + cache_v_scale, + out=None, + alloc_tensor_func=torch.empty, +): + BLOCK_SEQ = 256 + batch_size = infer_state.batch_size + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num = q.shape[1] + head_dim = q.shape[2] + calcu_shape1 = (batch_size, q_head_num, head_dim) + + from 
..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out + + mid_o = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" + ) + mid_o_logexpsum = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" + ) + + from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1 + + int4kv_flash_decode_stage1( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + ) + + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + return o_tensor diff --git a/lightllm/models/mistral/triton_kernel/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py similarity index 100% rename from lightllm/models/mistral/triton_kernel/__init__.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/__init__.py diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py similarity index 71% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py index 88e39b82fc..f51d611661 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding.py @@ -5,8 +5,6 @@ def token_decode_attention_flash_decoding( q, infer_state, - q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -15,19 +13,20 @@ def token_decode_attention_flash_decoding( alloc_tensor_func=torch.empty, ): BLOCK_SEQ = 256 + q_head_num, head_dim = q.shape[1], q.shape[2] batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=q.dtype, device="cuda" ) light_ops.group8_int8kv_flashdecoding_stage1( @@ -43,7 +42,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py similarity index 88% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py index 84054bf867..6efb030ce6 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py @@ -10,8 +10,6 @@ def token_decode_attention_flash_decoding( q, infer_state: InferStateInfo, - q_head_num, - head_dim, cache_k, cache_k_scale, cache_v, @@ -28,18 +26,21 @@ def token_decode_attention_flash_decoding( stream1 = shared_streams_dict["stream1"] stream2 = shared_streams_dict["stream2"] + q_head_num = q.shape[1] + head_dim = q.shape[2] + BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" ) current_stream = torch.cuda.current_stream() @@ -56,7 +57,7 @@ def token_decode_attention_flash_decoding( B_req_idx=infer_state.b_req_idx, b_shared_seq_len=infer_state.b_shared_seq_len, b_mark_shared_group=infer_state.b_mark_shared_group, - max_len_in_batch=infer_state.max_len_in_batch, + max_len_in_batch=infer_state.max_kv_seq_len, mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, @@ -78,7 +79,7 @@ def token_decode_attention_flash_decoding( infer_state.b_req_idx, infer_state.b_seq_len, infer_state.b_shared_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) current_stream.wait_stream(stream1) diff --git a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py similarity index 98% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py index 8b3423ce99..7403f6dd5c 100644 --- a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py @@ -269,12 +269,13 @@ def flash_decode_stage1( gqa_group_size = q.shape[1] // k.shape[1] assert triton.next_power_of_2(Lk) == Lk KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] - assert KV_QUANT_GROUP_SIZE == 8 + assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) if BLOCK_HEAD * BLOCK_BATCH < 16: BLOCK_BATCH = 16 // BLOCK_HEAD + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, stride_qbs=q.stride(0), diff --git 
a/lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py similarity index 63% rename from lightllm/models/phi3/triton_kernel/flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py index e47e308864..6c50fc3927 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py @@ -1,12 +1,11 @@ import torch -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len + q_head_num, head_dim = q.shape[1], q.shape[2] calcu_shape1 = (batch_size, q_head_num, head_dim) from .flash_decoding_stage1 import flash_decode_stage1 @@ -15,10 +14,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" ) flash_decode_stage1( @@ -28,7 +27,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, mid_o, mid_o_logexpsum, BLOCK_SEQ, diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py similarity index 88% rename from lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py index f6d8b5abee..f41a5c8fde 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage1.py @@ -33,7 +33,6 @@ def _fwd_kernel_flash_decode_stage1( stride_mid_o_eh, 
stride_mid_o_es, gqa_group_size, - head_dim, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -62,7 +61,7 @@ def _fwd_kernel_flash_decode_stage1( offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - q = tl.load(Q + off_q, mask=offs_d < head_dim, other=0.0) + q = tl.load(Q + off_q) sum_exp = 0.0 max_logic = -float("inf") @@ -75,16 +74,13 @@ def _fwd_kernel_flash_decode_stage1( mask=offs_n_new < cur_batch_end_index, other=0, ) + k_loc = k_loc.to(tl.int64) off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - k = tl.load( - K + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) + k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) - v = tl.load( - V + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 - ) + v = tl.load(V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) cur_max_logic = tl.max(att_value, axis=0) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -101,7 +97,7 @@ def _fwd_kernel_flash_decode_stage1( for _ in range(0, need_store, 1): off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp, mask=offs_d < head_dim) + tl.store(Mid_O + off_mid_o, acc / sum_exp) tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -116,13 +112,12 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) + assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] - + assert k.stride() == v.stride() _fwd_kernel_flash_decode_stage1[grid]( q, k, @@ -152,9 +147,8 @@ def flash_decode_stage1( mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), gqa_group_size, - head_dim, BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, num_warps=1, num_stages=2, diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py similarity index 81% rename from lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py index a06ee54545..101e99dde5 100644 --- a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding_stage2.py @@ -8,7 +8,7 @@ def _fwd_kernel_flash_decode_stage2( B_Seqlen, Mid_O, # [batch, head, seq_block_num, head_dim] Mid_O_LogExpSum, # [batch, head, seq_block_num] - Out, # [batch, head, head_dim] + O, # [batch, head, head_dim] stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -19,7 +19,6 @@ def _fwd_kernel_flash_decode_stage2( stride_obs, stride_oh, stride_od, - head_dim, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): @@ -38,7 +37,7 @@ def _fwd_kernel_flash_decode_stage2( offs_v = cur_batch * stride_mid_ob + cur_head * 
stride_mid_oh + offs_d offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0) + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) new_max_logic = tl.maximum(tlogic, max_logic) @@ -49,23 +48,22 @@ def _fwd_kernel_flash_decode_stage2( sum_exp = sum_exp * old_scale + exp_logic max_logic = new_max_logic - tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim) + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) return @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] - head_dim = Lk + assert Lk in {16, 32, 64, 128} batch, head_num = mid_out.shape[0], mid_out.shape[1] - BLOCK_DMODEL = triton.next_power_of_2(head_dim) grid = (batch, head_num) _fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, mid_out, mid_out_logexpsum, - Out, + out, mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), @@ -73,12 +71,11 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_dim, + out.stride(0), + out.stride(1), + out.stride(2), BLOCK_SEQ=block_seq, - BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DMODEL=Lk, num_warps=4, num_stages=2, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py similarity index 63% rename from lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py index 9a8261132a..9de2b82057 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_att1.py @@ -14,8 +14,6 @@ def _fwd_kernel_token_att1( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, Att_Out, stride_req_to_tokens_b, stride_req_to_tokens_s, @@ -28,7 +26,6 @@ def _fwd_kernel_token_att1( att_stride_h, att_stride_bs, kv_group_num, - sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -38,38 +35,32 @@ def _fwd_kernel_token_att1( cur_kv_head = cur_head // kv_group_num - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) - # use new start index of k value - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) + cur_batch_start_index = 0 cur_batch_end_index = cur_batch_seq_len - off_q = cur_batch * stride_qbs 
+ cur_head * stride_qh + offs_d * stride_qd # [D] + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32] + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - # use new value to decide block mask block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number + block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark - offs_n_new = cur_batch_start_index + offs_n # the latest window of token + q = tl.load(Q + off_q + start_mark) + offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0, - ) - off_k = ( - k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - ) # [32, D], find token index + ).to(tl.int64) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32] - att_value = att_value.to(tl.float32) + att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32) att_value *= sm_scale off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) @@ -77,19 +68,17 @@ def _fwd_kernel_token_att1( @torch.no_grad() -def token_att_fwd( - q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): +def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, max_len_in_batch): BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256} sm_scale = 1.0 / (Lk ** 0.5) batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK)) + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) kv_group_num = q.shape[1] // k.shape[1] if kv_group_num == 1: @@ -105,8 +94,6 @@ def token_att_fwd( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, att_out, Req_to_tokens.stride(0), Req_to_tokens.stride(1), @@ -119,7 +106,6 @@ def token_att_fwd( att_out.stride(0), att_out.stride(1), kv_group_num=kv_group_num, - sliding_window=sliding_window, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py similarity index 59% rename from lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py index acf4923f82..96a5b26dd6 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_reduceV.py @@ -13,8 +13,6 @@ def _fwd_kernel_token_att2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, stride_req_to_tokens_b, stride_req_to_tokens_s, stride_ph, @@ -26,7 +24,6 @@ def _fwd_kernel_token_att2( stride_oh, 
stride_od, kv_group_num, - sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -35,36 +32,30 @@ def _fwd_kernel_token_att2( cur_kv_head = cur_head // kv_group_num - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - # cur_batch_end_index = cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index + cur_batch_start_index = 0 + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # att length - v_loc_off = ( - cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - ) # the latest window of value [64] - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs # [64] - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] + v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s + p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs + v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_att_seq_len, other=0.0) # [64] + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) v_loc = tl.load( Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n + cur_batch_start_index) < cur_batch_seq_len, + mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0, - ) # [64] + ).to(tl.int64) v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, - other=0.0, - ) # [1, D] + [64, 1] = [64, D] - acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = [64, D] -> [D] + V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 + ) + acc += tl.sum(p_value[:, None] * v_value, 0) acc = acc.to(Out.dtype.element_ty) off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od @@ -74,9 +65,7 @@ def _fwd_kernel_token_att2( @torch.no_grad() -def token_att_fwd2( - prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window -): +def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): BLOCK = 128 # BLOCK = 64 # for triton 2.0.0dev batch, head = B_req_idx.shape[0], prob.shape[0] @@ -94,8 +83,6 @@ def token_att_fwd2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, Req_to_tokens.stride(0), Req_to_tokens.stride(1), prob.stride(0), @@ -107,7 +94,6 @@ def token_att_fwd2( out.stride(1), out.stride(2), kv_group_num=kv_group_num, - siliding_window=sliding_window, BLOCK_DMODEL=dim, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py similarity index 71% rename from lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py index 5e6040ac55..0bb6410e13 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_softmax.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_nopad_softmax.py @@ -5,11 +5,15 @@ @triton.jit def _fwd_kernel_token_softmax( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - stride_logic_h, stride_logic_bs, - stride_prob_h, stride_prob_bs, - BLOCK_SIZE: tl.constexpr + stride_logic_h, + stride_logic_bs, + stride_prob_h, + stride_prob_bs, + BLOCK_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -18,18 +22,25 @@ def _fwd_kernel_token_softmax( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32) + row = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, + mask=col_offsets < cur_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) - * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len) + tl.store( + Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, + softmax_output, + mask=col_offsets < cur_batch_seq_len, + ) return + @torch.no_grad() def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): BLOCK_SIZE = triton.next_power_of_2(max_input_len) @@ -42,20 +53,26 @@ def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): num_warps = 16 _fwd_kernel_token_softmax[(batch, head_num)]( - Logics, B_Start_Loc, B_Seqlen, + Logics, + B_Start_Loc, + B_Seqlen, Prob_Out, - Logics.stride(0), Logics.stride(1), - Prob_Out.stride(0), Prob_Out.stride(1), + Logics.stride(0), + Logics.stride(1), + Prob_Out.stride(0), + Prob_Out.stride(1), num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) return + def test1(): import torch B, N_CTX, H, D = 4, 1025, 12, 128 + del D dtype = torch.float16 diff --git a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py similarity index 100% rename from lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/mha/stage3_decode_att/token_attention_softmax_and_reducev.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py 
similarity index 51% rename from lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py index 8fda084605..b0a9b6245c 100644 --- a/lightllm/models/llama/triton_kernel/ppl_fp16_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/ppl_fp16/ppl_fp16_flash_decoding.py @@ -1,27 +1,27 @@ import torch +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): +def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty): BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + q_head_num = q.shape[1] + head_dim = q.shape[2] + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) - from lightllm_ppl_fp16_flashdecoding_kernel import fp16_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 + from ..mha.flash_decoding.flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" ) - fp16_flashdecoding_stage1( + light_ops.fp16_flashdecoding_stage1( BLOCK_SEQ, mid_o, mid_o_logexpsum, @@ -32,7 +32,7 @@ def token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_kv_seq_len, ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py similarity index 81% rename from lightllm/models/llama/triton_kernel/context_flashattention_nopad.py rename to lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index e36c51b394..5ba6d0beb6 100644 --- a/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -336,25 +336,24 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma @triton.jit -def _fwd_kernel_int8kv( +def _fwd_kernel_contiguous_kv( Q, K, V, sm_scale, Out, B_Start_Loc, + B_kv_start_loc, B_Seqlen, b_prompt_cache_len, stride_qbs, stride_qh, stride_qd, - stride_kb, - stride_kh, stride_ks, + stride_kh, stride_kd, - stride_vb, - stride_vh, stride_vs, + stride_vh, stride_vd, stride_obs, stride_oh, @@ -374,6 +373,7 @@ def _fwd_kernel_int8kv( prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) cur_batch_in_all_start_index = 
tl.load(B_Start_Loc + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + kv_start_loc = tl.load(B_kv_start_loc + cur_batch) block_start_loc = BLOCK_M * start_m @@ -393,6 +393,9 @@ def _fwd_kernel_int8kv( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + stride_ks = tl.cast(stride_ks, tl.int64) + stride_vs = tl.cast(stride_vs, tl.int64) + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) # causal mask @@ -405,8 +408,7 @@ def _fwd_kernel_int8kv( # other=0, # ) off_k = ( - cur_batch * stride_kb - + (start_n + offs_n[None, :]) * stride_ks + (kv_start_loc + start_n + offs_n[None, :]) * stride_ks + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd ) @@ -432,8 +434,7 @@ def _fwd_kernel_int8kv( # other=0.0, # ) off_v = ( - cur_batch * stride_vb - + (start_n + offs_n[:, None]) * stride_vs + (kv_start_loc + start_n + offs_n[:, None]) * stride_vs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd ) @@ -455,7 +456,9 @@ def _fwd_kernel_int8kv( @torch.no_grad() -def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len): +def context_attention_fwd_contiguous_kv( + q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, max_q_input_len, b_prompt_cache_len +): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -468,34 +471,33 @@ def context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_inp batch, head = b_seq_len.shape[0], q.shape[1] kv_group_num = q.shape[1] // k.shape[1] - grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + grid = lambda meta: (triton.cdiv(max_q_input_len, meta["BLOCK_M"]), batch * head, 1) BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 - _fwd_kernel_int8kv[grid]( - q, - k, - v, - sm_scale, - o, - b_start_loc, - b_seq_len, - b_prompt_cache_len, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - k.stride(3), - v.stride(0), - v.stride(1), - v.stride(2), - v.stride(3), - o.stride(0), - o.stride(1), - o.stride(2), + _fwd_kernel_contiguous_kv[grid]( + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + Out=o, + B_Start_Loc=b_start_loc, + B_kv_start_loc=b_kv_start_loc, + B_Seqlen=b_seq_len, + b_prompt_cache_len=b_prompt_cache_len, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + stride_ks=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + stride_vs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + stride_obs=o.stride(0), + stride_oh=o.stride(1), + stride_od=o.stride(2), kv_group_num=kv_group_num, H=head, BLOCK_DMODEL=Lk, @@ -596,86 +598,5 @@ def test(): assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) -def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_seq_len, b_prompt_cache_len): - - batch = b_start_loc.shape[0] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - for i in range(batch): - start_loc = b_start_loc[i] - seq_len = b_seq_len[i] - prompt_cache_len = b_prompt_cache_len[i] - cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] - cur_q = cur_q.clone().to(torch.float32) - cur_k = k[i, :seq_len, :] - cur_k = cur_k.clone().to(torch.float32) - - cur_v = v[i, :seq_len, :] - cur_v = cur_v.clone().to(torch.float32) - - cur_q = cur_q.transpose(0, 1) - cur_k = cur_k.transpose(0, 1) - 
cur_v = cur_v.transpose(0, 1) - dk = cur_q.shape[-1] - - p = torch.matmul(cur_q, cur_k.transpose(-2, -1)) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) - - q_index = torch.arange(cur_q.shape[1]).unsqueeze(-1).to(p.device) - k_index = torch.arange(cur_k.shape[1]).unsqueeze(0).to(p.device) - mask = (q_index + prompt_cache_len >= k_index).int() - mask = mask.unsqueeze(0).expand(cur_q.shape[0], -1, -1) - - p = p.masked_fill(mask == 0, float("-inf")) - - s = F.softmax(p, dim=-1) - - o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) - - -def test2(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 16, 16, 2048, 128 - dtype = torch.float16 - prompt_cache_len = 0 - q = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - kv = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - k = kv[:, :H] - v = kv[:, H:] - # v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - torch_o = torch.empty((Z * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.3, std=0.2 - ) - max_input_len = N_CTX - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(Z, dtype=torch.int32, device="cuda") - - for i in range(Z): - b_seq_len[i] = N_CTX - if i != 0: - b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len - b_prompt_cache_len[i] = prompt_cache_len - torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_seq_len, b_prompt_cache_len) - - import time - - torch.cuda.synchronize() - a = time.time() - for i in range(1000): - context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len) - torch.cuda.synchronize() - b = time.time() - # print(o.shape, torch_out.shape) - print((b - a)) - - print("max ", torch.max(torch.abs(torch_o - o))) - print("mean ", torch.mean(torch.abs(torch_o - o))) - assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) - - if __name__ == "__main__": test() - test2() diff --git a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py index bd53de386e..060a92bf7c 100644 --- a/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py @@ -16,12 +16,13 @@ def _fwd_kernel_destindex_copy_kv( stride_o_h, stride_o_d, head_num, + head_dim, BLOCK_DMODEL: tl.constexpr, BLOCK_HEAD: tl.constexpr, ): cur_index = tl.program_id(0) offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = (tl.arange(0, BLOCK_DMODEL)) % head_dim dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) @@ -54,133 +55,10 @@ def destindex_copy_kv(K, DestLoc, Out): Out.stride(1), Out.stride(2), head_num, - BLOCK_DMODEL=head_dim, + head_dim, + BLOCK_DMODEL=triton.next_power_of_2(head_dim), BLOCK_HEAD=BLOCK_HEAD, num_warps=num_warps, num_stages=1, ) return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - 
cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=offs_h[:, None] < head_num, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num) - tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - BLOCK_DMODEL=head_dim, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py index 9804e46681..c8a6a850bc 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py @@ -9,8 +9,6 @@ def gen_decode_params(b_seq_len: 
torch.Tensor): b_kv_seq_len = b_seq_len position_ids = b_seq_len - 1 - mtp_step = get_env_start_args().mtp_step - mtp_size = mtp_step + 1 - b_q_seq_len = torch.ones(b_seq_len.shape[0] // mtp_size, dtype=torch.int32, device=b_seq_len.device) * mtp_size - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size]) + b_q_seq_len = torch.ones(b_seq_len.shape[0], dtype=torch.int32, device=b_seq_len.device) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index e73b342994..8f9172b552 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -43,6 +43,7 @@ def _gen_cumsum_pad0_kernel( def gen_cumsum_pad0_tensor(b_q_seq_len: torch.Tensor, b_kv_seq_len: torch.Tensor): assert len(b_q_seq_len.shape) == 1 assert b_q_seq_len.shape == b_kv_seq_len.shape + assert b_q_seq_len.is_contiguous() b1_cu_q_seq_len = torch.empty((b_q_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") b1_cu_kv_seq_len = torch.empty((b_kv_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py b/lightllm/common/basemodel/triton_kernel/kv_copy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py similarity index 92% rename from lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py rename to lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py index 39deb1b6f7..41a25877a7 100644 --- a/lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/mla_copy_kv.py @@ -36,11 +36,11 @@ def _fwd_kernel_destindex_copy_kv( dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :] - kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :] + kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope + kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope - o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :] - o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :] + o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope + o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope kv_nope = tl.load(kv_nope_ptrs) kv_rope = tl.load(kv_rope_ptrs) @@ -60,6 +60,9 @@ def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope): assert KV_nope.shape[2] == O_nope.shape[2] assert KV_rope.shape[1] == O_rope.shape[1] assert KV_rope.shape[2] == O_rope.shape[2] + assert triton.next_power_of_2(kv_nope_head_dim) == kv_nope_head_dim + assert triton.next_power_of_2(kv_rope_head_dim) == kv_rope_head_dim + grid = (seq_len,) num_warps = 1 diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py new file mode 100644 index 0000000000..53d1256ec7 --- 
/dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int4kv_copy_kv.py @@ -0,0 +1,374 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_int4_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_count, + token_num, + HEAD_NUM: tl.constexpr, + BLOCK_GROUP_COUNT: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_index = tl.program_id(0) + + for cur_index in range(start_index, token_num, step=tl.num_programs(axis=0)): + offs_g = tl.arange(0, BLOCK_GROUP_COUNT) % group_count + offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + for cur_head in tl.static_range(HEAD_NUM, step=1): + src_data_0 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2, + ) + src_data_1 = tl.load( + K + + cur_index * stride_k_bs + + cur_head * stride_k_h + + offs_g[:, None] * stride_k_g + + offs_d[None, :] * 2 + + 1, + ) + + abs_data_0 = tl.abs(src_data_0) + abs_data_1 = tl.abs(src_data_1) + + data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to( + Out_scale.dtype.element_ty + ) + q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) + q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) + q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) + q_src_data_0 += 7 + q_src_data_0 = q_src_data_0.to(tl.uint8, bitcast=True) + + q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) + q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) + q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) + q_src_data_1 += 7 + q_src_data_1 = q_src_data_1.to(tl.uint8, bitcast=True) + + low_4 = q_src_data_0 & 0xF + high_4 = (q_src_data_1 & 0xF) << 4 + + out_data = (low_4 | high_4).to(Out.dtype.element_ty, bitcast=True) + + o_ptrs = ( + Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + ) + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, out_data) + tl.store(os_ptrs, data_scale) + return + + +@torch.no_grad() +def destindex_copy_int4kv( + KV: torch.Tensor, + DestLoc: torch.Tensor, + KV_buffer: torch.Tensor, + KV_scale_buffer: torch.Tensor, + quant_group_size: int, +): + head_num = KV.shape[1] + head_dim = KV.shape[2] + + assert head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + + group_count = head_dim // quant_group_size + group_dim = quant_group_size + + assert triton.next_power_of_2(group_dim) == group_dim + + KV = KV.view((KV.shape[0], head_num, group_count, group_dim)) + KV_buffer = KV_buffer.view( + KV_buffer.shape[0], KV_buffer.shape[1], group_count, group_dim // 2 + ) # Out is int8: two int4 values are packed into one int8, hence group_dim // 2 + KV_scale_buffer = KV_scale_buffer.view(KV_scale_buffer.shape[0], KV_scale_buffer.shape[1], group_count) + if len(DestLoc) < 1024: + grid = (len(DestLoc),) + else: + grid = (1024,) + + _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( + K=KV, + Dest_loc=DestLoc, + Out=KV_buffer, + Out_scale=KV_scale_buffer, + stride_k_bs=KV.stride(0), + stride_k_h=KV.stride(1), + stride_k_g=KV.stride(2), + stride_k_d=KV.stride(3), + stride_o_bs=KV_buffer.stride(0), + stride_o_h=KV_buffer.stride(1), + stride_o_g=KV_buffer.stride(2), + 
stride_o_d=KV_buffer.stride(3), + stride_os_bs=KV_scale_buffer.stride(0), + stride_os_h=KV_scale_buffer.stride(1), + stride_os_g=KV_scale_buffer.stride(2), + group_count=group_count, + token_num=len(DestLoc), + HEAD_NUM=head_num, + BLOCK_GROUP_COUNT=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=group_dim, + num_warps=4, + num_stages=1, + ) + return + + +@triton.jit +def int4_to_float(k_int8, offs_d): + k_int8 = k_int8.to(tl.uint8, bitcast=True) + k_high = (k_int8 & 0xF0) >> 4 + k_low = k_int8 & 0x0F + k_high = k_high.to(tl.int8, bitcast=True) + k_low = k_low.to(tl.int8, bitcast=True) + k_high -= 7 + k_low -= 7 + + k_int4 = tl.where( + offs_d[None, None, :] % 2 == 0, + k_low, + k_high, + ) + return k_int4 + + +@triton.jit +def _fwd_dequantize_int4kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): + k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] // 2 + ) + k_int4 = int4_to_float(k_int8, offs_d) + + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + + group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int4.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] // 2 + ) + v_int4 = int4_to_float(v_int8, offs_d) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int4.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int4kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, 
+ v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] * 2 + v_head_num = v.shape[1] + v_head_dim = v.shape[2] * 2 + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + assert triton.next_power_of_2(group_dim) == group_dim + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim // 2)) # int4 kv is stored packed in int8 + v = v.view((v.shape[0], v.shape[1], group_count, group_dim // 2)) + k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) + v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1)) + + # make sure the split grid has enough parallelism + SEQ_BLOCK_SIZE = 128 + while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512: + SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2 + if SEQ_BLOCK_SIZE <= 1: + break + + if SEQ_BLOCK_SIZE <= 1: + SEQ_BLOCK_SIZE = 8 + + grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size) + num_warps = 4 + k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim)) + v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim)) + + _fwd_dequantize_int4kv[grid]( + k=k, + k_ss=k.stride(0), + k_sh=k.stride(1), + k_sg=k.stride(2), + k_sd=k.stride(3), + k_scale=k_scale, + k_scale_ss=k_scale.stride(0), + k_scale_sh=k_scale.stride(1), + k_scale_sg=k_scale.stride(2), + k_scale_sd=k_scale.stride(2), + v=v, + v_ss=v.stride(0), + v_sh=v.stride(1), + v_sg=v.stride(2), + v_sd=v.stride(3), + v_scale=v_scale, + v_scale_ss=v_scale.stride(0), + v_scale_sh=v_scale.stride(1), + v_scale_sg=v_scale.stride(2), + v_scale_sd=v_scale.stride(3), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + stride_req_to_tokens_s=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=k_out, + k_out_ss=k_out.stride(0), + k_out_sh=k_out.stride(1), + k_out_sg=k_out.stride(2), + k_out_sd=k_out.stride(3), + v_out=v_out, + v_out_ss=v_out.stride(0), + v_out_sh=v_out.stride(1), + v_out_sg=v_out.stride(2), + v_out_sd=v_out.stride(3), + k_head_num=k_head_num, + v_head_num=v_head_num, + group_count=group_count, + group_dim=group_dim, + SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, + GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=group_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py new file mode 100644 index 0000000000..e5ee5cb8b8 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_copy/ppl_int8kv_copy_kv.py @@ -0,0 +1,330 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + 
stride_k_h, + stride_k_g, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_g, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_g, + group_size, + BLOCK_GROUP_NUM: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_g = tl.arange(0, BLOCK_GROUP_NUM) + offs_d = tl.arange(0, BLOCK_GROUP_DIM) + + dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) + + src_data = tl.load( + K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], + mask=offs_g[:, None] < group_size, + other=0.0, + ) + abs_data = tl.abs(src_data) + data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) + q_src_data = (src_data / data_scale[:, None]).to(tl.int8) + + o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] + os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g + tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) + tl.store(os_ptrs, data_scale, mask=offs_g < group_size) + return + + +@torch.no_grad() +def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale, quant_group_dim): + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + assert triton.next_power_of_2(quant_group_dim) == quant_group_dim, "error quant group dim" + + assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" + grid = (seq_len, head_num) + num_warps = 1 + + group_size = head_dim // quant_group_dim + group_dim = quant_group_dim + + K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) + Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) + + _fwd_kernel_destindex_copy_quantize_kv[grid]( + K, + DestLoc, + Out, + Out_scale, + K.stride(0), + K.stride(1), + K.stride(2), + K.stride(3), + Out.stride(0), + Out.stride(1), + Out.stride(2), + Out.stride(3), + Out_scale.stride(0), + Out_scale.stride(1), + Out_scale.stride(2), + group_size, + BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), + BLOCK_GROUP_DIM=group_dim, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_dequantize_int8kv( + k, + k_ss, + k_sh, + k_sg, + k_sd, + k_scale, + k_scale_ss, + k_scale_sh, + k_scale_sg, + k_scale_sd, + v, + v_ss, + v_sh, + v_sg, + v_sd, + v_scale, + v_scale_ss, + v_scale_sh, + v_scale_sg, + v_scale_sd, + req_to_token_indexs, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + b_seq_len, + b_req_idx, + b_kv_start_loc, + k_out, + k_out_ss, + k_out_sh, + k_out_sg, + k_out_sd, + v_out, + v_out_ss, + v_out_sh, + v_out_sg, + v_out_sd, + k_head_num, + v_head_num, + group_count, + group_dim, + SEQ_BLOCK_SIZE: tl.constexpr, + GROUP_COUNT_BLOCK_SIZE: tl.constexpr, + BLOCK_GROUP_DIM: tl.constexpr, +): + start_block_index = tl.program_id(0) + cur_batch = tl.program_id(1) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_seq_len = tl.load(b_seq_len + cur_batch) + if start_block_index * SEQ_BLOCK_SIZE >= cur_seq_len: + return + + out_start_loc = tl.load(b_kv_start_loc + cur_batch) + + offs_kv_loc = (start_block_index * SEQ_BLOCK_SIZE + tl.arange(0, SEQ_BLOCK_SIZE)) % cur_seq_len + kv_loc = tl.load(req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc).to(tl.int64) + + offs_d = tl.arange(0, BLOCK_GROUP_DIM) % group_dim + offs_scale_d = tl.arange(0, 1) + group_offs = tl.arange(0, GROUP_COUNT_BLOCK_SIZE) % group_count + + for k_head_index in tl.range(0, k_head_num, step=1, num_stages=3): 
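+ # For each K head: gather the packed int8 codes addressed through req_to_token_indexs, multiply by the per-group scales, and store the dequantized values into the contiguous k_out buffer.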
+ k_int8 = tl.load( + k + + kv_loc[:, None, None] * k_ss + + k_head_index * k_sh + + group_offs[None, :, None] * k_sg + + offs_d[None, None, :] + ) + k_scale_data = tl.load( + k_scale + + kv_loc[:, None, None] * k_scale_ss + + k_head_index * k_scale_sh + + group_offs[None, :, None] * k_scale_sg + + offs_scale_d[None, None, :] + ) + k_out_data = k_int8.to(k_out.dtype.element_ty) * k_scale_data + tl.store( + k_out + + (out_start_loc + offs_kv_loc[:, None, None]) * k_out_ss + + k_head_index * k_out_sh + + group_offs[None, :, None] * k_out_sg + + offs_d[None, None, :], + k_out_data, + ) + + for v_head_index in tl.range(0, v_head_num, step=1, num_stages=3): + v_int8 = tl.load( + v + + kv_loc[:, None, None] * v_ss + + v_head_index * v_sh + + group_offs[None, :, None] * v_sg + + offs_d[None, None, :] + ) + v_scale_data = tl.load( + v_scale + + kv_loc[:, None, None] * v_scale_ss + + v_head_index * v_scale_sh + + group_offs[None, :, None] * v_scale_sg + + offs_scale_d[None, None, :] + ) + v_out_data = v_int8.to(v_out.dtype.element_ty) * v_scale_data + tl.store( + v_out + + (out_start_loc + offs_kv_loc[:, None, None]) * v_out_ss + + v_head_index * v_out_sh + + group_offs[None, :, None] * v_out_sg + + offs_d[None, None, :], + v_out_data, + ) + return + + +@torch.no_grad() +def dequantize_int8kv( + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_req_idx: torch.Tensor, + b_kv_start_loc: torch.Tensor, + k_out: torch.Tensor, + v_out: torch.Tensor, + max_len_in_batch: int, + quant_group_size: int, +): + batch_size = b_seq_len.shape[0] + k_head_num = k.shape[1] + k_head_dim = k.shape[2] + v_head_num = v.shape[1] + v_head_dim = v.shape[2] + assert k_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert v_head_dim % quant_group_size == 0, "error head dim, can not been supported to copy quant kv" + assert k_head_dim == v_head_dim, "error head dim, can not been supported to copy quant kv" + assert k_head_dim // v_scale.shape[2] == quant_group_size, "error head dim, can not been supported to copy quant kv" + assert k_head_dim in [64, 128, 256] + + group_count = k_head_dim // quant_group_size + group_dim = quant_group_size + + k = k.view((k.shape[0], k.shape[1], group_count, group_dim)) + v = v.view((v.shape[0], v.shape[1], group_count, group_dim)) + k_scale = k_scale.view((k_scale.shape[0], k_scale.shape[1], group_count, 1)) + v_scale = v_scale.view((v_scale.shape[0], v_scale.shape[1], group_count, 1)) + + # make sure the split grid has enough parallelism + SEQ_BLOCK_SIZE = 128 + while triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE) * batch_size < 512: + SEQ_BLOCK_SIZE = SEQ_BLOCK_SIZE // 2 + if SEQ_BLOCK_SIZE <= 1: + break + + if SEQ_BLOCK_SIZE <= 1: + SEQ_BLOCK_SIZE = 8 + + grid = (triton.cdiv(max_len_in_batch, SEQ_BLOCK_SIZE), batch_size) + num_warps = 4 + k_out = k_out.view((k_out.shape[0], k_out.shape[1], group_count, group_dim)) + v_out = v_out.view((v_out.shape[0], v_out.shape[1], group_count, group_dim)) + + _fwd_dequantize_int8kv[grid]( + k=k, + k_ss=k.stride(0), + k_sh=k.stride(1), + k_sg=k.stride(2), + k_sd=k.stride(3), + k_scale=k_scale, + k_scale_ss=k_scale.stride(0), + k_scale_sh=k_scale.stride(1), + k_scale_sg=k_scale.stride(2), + k_scale_sd=k_scale.stride(2), + v=v, + v_ss=v.stride(0), + v_sh=v.stride(1), + v_sg=v.stride(2), + v_sd=v.stride(3), + v_scale=v_scale, + v_scale_ss=v_scale.stride(0), + v_scale_sh=v_scale.stride(1), + v_scale_sg=v_scale.stride(2), + 
v_scale_sd=v_scale.stride(3), + req_to_token_indexs=req_to_token_indexs, + stride_req_to_tokens_b=req_to_token_indexs.stride(0), + stride_req_to_tokens_s=req_to_token_indexs.stride(1), + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=k_out, + k_out_ss=k_out.stride(0), + k_out_sh=k_out.stride(1), + k_out_sg=k_out.stride(2), + k_out_sd=k_out.stride(3), + v_out=v_out, + v_out_ss=v_out.stride(0), + v_out_sh=v_out.stride(1), + v_out_sg=v_out.stride(2), + v_out_sd=v_out.stride(3), + k_head_num=k_head_num, + v_head_num=v_head_num, + group_count=group_count, + group_dim=group_dim, + SEQ_BLOCK_SIZE=SEQ_BLOCK_SIZE, + GROUP_COUNT_BLOCK_SIZE=triton.next_power_of_2(group_count), + BLOCK_GROUP_DIM=triton.next_power_of_2(group_dim), + num_warps=num_warps, + num_stages=1, + ) + return + + +def test2(): + import time + + B, N_CTX, H, D = 1, 3, 12, 128 + src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() + dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() + value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) + scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda() + + for _ in range(10): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t1 = time.time() + for _ in range(1000): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t2 = time.time() + + print("Time cost ", t2 - t1) + value_dest = value_dest.view((B * N_CTX, H, D // 8, 8)) + scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1)) + print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) + print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) + cos = torch.nn.CosineSimilarity(0) + print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py new file mode 100644 index 0000000000..fb0609a401 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/__init__.py @@ -0,0 +1 @@ +from .gqa_flash_decoding import gqa_token_decode_attention_flash_decoding diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py similarity index 92% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 256dfce5af..28839b5f59 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -12,28 +12,21 @@ def gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv_nope, - kv_rope, - infer_state, - q_head_num, - kv_lora_rank, - q_rope_dim, - qk_nope_head_dim, - softmax_scale, - out=None, - alloc_tensor_func=torch.empty, - **run_config + q_nope, q_rope, kv_nope, kv_rope, infer_state, softmax_scale, out=None, alloc_tensor_func=torch.empty, **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = 
infer_state.max_kv_seq_len + + q_head_num, kv_lora_rank = q_nope.shape[1], q_nope.shape[2] + q_rope_dim = q_rope.shape[2] + assert q_rope_dim == 64 + calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_config.py diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage1.py diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py rename to lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding_stage2.py diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py new file mode 100644 index 0000000000..5725bed2e7 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/__init__.py @@ -0,0 +1 @@ +from .context_flashattention_nopad_with_v import context_attention_fwd_with_v diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py b/lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py rename to lightllm/common/basemodel/triton_kernel/mla_att/prefill_att/context_flashattention_nopad_with_v.py diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py similarity index 100% rename from lightllm/models/deepseek2/triton_kernel/repack_kv_index.py rename to lightllm/common/basemodel/triton_kernel/repack_kv_index.py diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py index 66caf5d789..7d516e6728 100644 --- a/lightllm/common/kv_cache_mem_manager/__init__.py +++ b/lightllm/common/kv_cache_mem_manager/__init__.py @@ -1,20 +1,16 @@ from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager -from .int8kv_mem_manager import INT8KVMemoryManager from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager from .export_calibration_mem_manager import ExportCalibrationMemoryManager from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager from .deepseek2_mem_manager import Deepseek2MemoryManager -from .deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager __all__ = [ "MemoryManager", "ReadOnlyStaticsMemoryManager", - 
"INT8KVMemoryManager", "CalibrationFP8KVMemoryManager", "ExportCalibrationMemoryManager", "PPLINT4KVMemoryManager", "PPLINT8KVMemoryManager", "Deepseek2MemoryManager", - "Deepseek2FP8KVMemoryManager", ] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py deleted file mode 100644 index 00699f4b15..0000000000 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_fp8kv_mem_manager.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from .deepseek2_mem_manager import Deepseek2MemoryManager - - -class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): - # scale被追加到kv_buffer末尾, 因此加2, dtype统一改成uint8 - super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 7711734601..3d93e1b070 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -3,13 +3,14 @@ import torch.distributed as dist from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager -from typing import List, Union +from typing import List, Union, Any from lightllm.utils.log_utils import init_logger from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io + logger = init_logger(__name__) @@ -17,6 +18,29 @@ class Deepseek2MemoryManager(MemoryManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv + + rope_dim = 64 + kv_lora_rank = kv.shape[2] - rope_dim + assert kv_lora_rank + rope_dim == self.kv_buffer.shape[-1] + + destindex_copy_kv( + kv[:, :, :kv_lora_rank], + kv[:, :, kv_lora_rank:], + mem_index, + self.kv_buffer[layer_index][:, :, :kv_lora_rank], + self.kv_buffer[layer_index][:, :, kv_lora_rank:], + ) + return + + def get_att_input_params(self, layer_index: int) -> Any: + kv = self.kv_buffer[layer_index] + return kv + def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py index b2749176ea..ffdc9b2c94 100755 --- a/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py @@ -1,6 +1,28 @@ +import torch +from typing import Tuple, Any from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager class ExportCalibrationMemoryManager(OfflineFP8QuantMemManager): def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, 
mem_fraction, is_export_mode=True) + + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 + + scales = self.scales + destindex_copy_kv_fp8( + kv, + mem_index, + scales[layer_index] if scales is not None else None, + self.kv_buffer[layer_index].view(torch.float8_e4m3fn), + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v diff --git a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py deleted file mode 100755 index 5725cdb7bb..0000000000 --- a/lightllm/common/kv_cache_mem_manager/int8kv_mem_manager.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch - -from .mem_manager import MemoryManager - - -class INT8KVMemoryManager(MemoryManager): - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=0.9): - self.kv_dtype = torch.int8 - super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True, mem_fraction=mem_fraction) - - def get_cell_size(self): - return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( - self.kv_dtype - ) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype) - - def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") - self.scale_buffer = torch.empty((layer_num, size + 1, 2 * head_num, 1), dtype=dtype, device="cuda") - - def _free_buffers(self): - self.kv_buffer = None - self.scale_buffer = None - - def get_index_kv_buffer(self, index): - return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]} - - def load_index_kv_buffer(self, index, load_tensor_dict): - self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) - self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"]) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009f..1203cbdec7 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,7 +3,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from typing import List, Union +from typing import List, Union, Tuple, Any from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -21,6 +21,7 @@ from multiprocessing.reduction import ForkingPickler from filelock import FileLock + logger = init_logger(__name__) @@ -65,6 +66,20 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False ) self.HOLD_TOKEN_MEMINDEX = self.size + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv + + destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index]) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] 
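+ # kv_buffer packs the K heads first and the V heads second along the head axis,
+ # so slicing at self.head_num yields the per-layer K and V views without copying.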
+ v = self.kv_buffer[layer_index][:, self.head_num :, :] + return k, v + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 259c5a56f8..1ff58b89a0 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -1,12 +1,10 @@ from . import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -18,8 +16,6 @@ @lru_cache(maxsize=None) def select_mem_manager_class(): - mode = get_env_start_args().mode - # case 1 # 先判断是否是 deepseek 系列的模型 model_class = get_llm_model_class() @@ -27,38 +23,25 @@ def select_mem_manager_class(): if issubclass(model_class, Deepseek2TpPartModel): mem_class = Deepseek2MemoryManager - if "triton_fp8kv" in mode: - mem_class = Deepseek2FP8KVMemoryManager - - logger.info(f"Model kv cache using mode {mode}, mem_manager class: {mem_class}") + logger.info(f"Model kv cache using default, mem_manager class: {mem_class}") return mem_class # case normal - logger.info(f"mode setting params: {mode}") - if "ppl_int8kv" in mode or "ppl_int8kv_flashdecoding" in mode or "ppl_int8kv_flashdecoding_diverse" in mode: + logger.info(f"mode setting params: {get_env_start_args().llm_kv_type}") + if get_env_start_args().llm_kv_type == "int8kv": memory_manager_class = PPLINT8KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "ppl_int4kv_flashdecoding" in mode: + elif get_env_start_args().llm_kv_type == "int4kv": memory_manager_class = PPLINT4KVMemoryManager - logger.info(f"Model kv cache using mode {mode}") - elif "triton_int8kv" in mode: - memory_manager_class = INT8KVMemoryManager - logger.info("Model kv cache using mode triton int8kv") - elif "triton_fp8kv" in mode: - raise Exception("currently only for deepseek") - elif "offline_calibration_fp8kv" in mode: - memory_manager_class = CalibrationFP8KVMemoryManager - logger.info("Model kv cache using mode offline calibration fp8kv") - elif "export_fp8kv_calibration" in mode: + elif get_env_start_args().llm_kv_type == "fp8kv": memory_manager_class = ExportCalibrationMemoryManager - logger.info("Using mode export fp8kv calibration") - else: + elif get_env_start_args().llm_kv_type == "None": memory_manager_class = MemoryManager - logger.info("Model kv cache using mode normal") + + logger.info(f"Model kv cache using mem_manager class: {memory_manager_class}") return memory_manager_class @lru_cache(maxsize=None) def used_mem_manager_has_scale() -> bool: mem_class = select_mem_manager_class() - return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, INT8KVMemoryManager] + return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager] diff --git a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py index 5cc0b12d03..56a79a3b57 100755 --- a/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/offline_fp8_quant_mem_manager.py @@ -31,8 +31,10 @@ def __init__( self.scales_list = None self.abs_max = None + enable_fa3 = "fa3" in 
get_env_start_args().llm_prefill_att_backend + if is_export_mode: - scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2] + scales_shape = [layer_num, 2 * head_num] if enable_fa3 else [layer_num, 2] self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda") elif get_env_start_args().kv_quant_calibration_config_path is not None: logger.info( @@ -43,7 +45,7 @@ def __init__( self.scales_list = cfg["scales"] self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) - if not get_env_start_args().enable_fa3: + if not enable_fa3: self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1) elif cfg["num_head"] > self.total_head_num: factor = cfg["num_head"] // self.total_head_num @@ -51,7 +53,7 @@ def __init__( elif cfg["num_head"] < self.total_head_num: factor = self.total_head_num // cfg["num_head"] self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() - if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: + if enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1: half_head = self.total_head_num // 2 start_head = dist.get_rank() * head_num end_head = start_head + head_num @@ -65,6 +67,8 @@ def __init__( logger.warning("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales") def _load_and_check_config(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + if os.path.exists(get_env_start_args().kv_quant_calibration_config_path): with open(get_env_start_args().kv_quant_calibration_config_path, "r") as f: cfg = json.load(f) @@ -86,7 +90,7 @@ def _load_and_check_config(self): raise ValueError( f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}" ) - if get_env_start_args().enable_fa3: + if enable_fa3: if cfg["quant_type"] != "per_head": raise ValueError(f"quant type {cfg['num_head']} in config not match fa3 backend") else: @@ -100,6 +104,7 @@ def _load_and_check_config(self): ) def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend inference_counts = get_kv_quant_calibration_inference_count() warmup_counts = get_kv_quant_calibration_warmup_count() if not get_model_init_status() or self.count >= warmup_counts + inference_counts: @@ -109,7 +114,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): logger.info("kv cache calibration mode will collect kv cache data for quantization calibration") if self.abs_max is not None and self.count >= warmup_counts: - if get_env_start_args().enable_fa3: + if enable_fa3: kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32) else: k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32) @@ -119,7 +124,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1: final_abs_max = self.abs_max if dist.is_initialized() and dist.get_world_size() > 1: - if get_env_start_args().enable_fa3: + if enable_fa3: k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1) k_max = k_max.contiguous() v_max = v_max.contiguous() @@ -144,11 +149,13 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int): self.count += 1 def _export_calibration_data(self): + enable_fa3 = "fa3" in get_env_start_args().llm_prefill_att_backend + 
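+ # fa3 stores per-head calibration scales, so the exported quant_type is
+ # "per_head"; other backends export per-tensor scales instead.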
model_arch = get_model_architectures(get_env_start_args().model_dir) cfg = { "version": "1.0", "architectures": model_arch, - "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor", + "quant_type": "per_head" if enable_fa3 else "per_tensor", "qmin": self.qmin, "qmax": self.qmax, "num_layers": self.layer_num, diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index f3218594d3..559980dc12 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -9,6 +9,28 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv + + destindex_copy_int4kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_size=self.group_quant_size, + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index 2a5aad7c8b..951d72e2c8 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -1,5 +1,5 @@ import torch - +from typing import Tuple, Any from .mem_manager import MemoryManager @@ -9,6 +9,28 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True, self.group_quant_size = 8 super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction) + def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + """ + 将每一层生成的kv拷贝到mem manager对应mem_index 位置中 + """ + from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv + + destindex_copy_quantize_kv( + kv, + mem_index, + self.kv_buffer[layer_index], + self.scale_buffer[layer_index], + quant_group_dim=self.group_quant_size, + ) + return + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + k = self.kv_buffer[layer_index][:, : self.head_num, :] + k_scale = self.scale_buffer[layer_index][:, : self.head_num, :] + v = self.kv_buffer[layer_index][:, self.head_num :, :] + v_scale = self.scale_buffer[layer_index][:, self.head_num :, :] + return (k, k_scale), (v, v_scale) + def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size( self.kv_dtype diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 4ee02f003b..539b32decb 100644 --- 
a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -1,4 +1,3 @@ -from lightllm.models.cohere.model import CohereTpPartModel from lightllm.models.mixtral.model import MixtralTpPartModel from lightllm.models.bloom.model import BloomTpPartModel from lightllm.models.llama.model import LlamaTpPartModel @@ -8,7 +7,6 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel -from lightllm.models.chatglm2.model import ChatGlm2TpPartModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 7938869f5a..f4fff116cd 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -10,9 +10,9 @@ class BloomPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): repair_config(config=network_config, same_names=["layer_norm_epsilon", "rms_norm_eps"]) - super().__init__(network_config, mode) + super().__init__(network_config) return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index baf1d3084d..dfe396ab52 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class BloomPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["layer_norm_epsilon"] return diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index d82a23d039..808788f71a 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -2,16 +2,15 @@ from typing import Tuple from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight -from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd -from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd from lightllm.common.basemodel import InferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl class BloomTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["layer_norm_epsilon"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = self.tp_q_head_num_ @@ -21,6 +20,40 @@ def __init__(self, layer_num, network_config, mode): self.embed_dim_ = network_config["n_embed"] return + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: 
BloomTransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor + + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl(use_alibi=True, tp_alibi=layer_weight.tp_alibi), + alloc_func=self.alloc_tensor, + ) + return o_tensor.view(q.shape) + def _att_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: @@ -42,47 +75,6 @@ def _get_qkv( cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) return q, cache_kv - def _context_attention_kernel( - self, q, kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - layer_weight.tp_alibi, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def _token_attention_kernel( - self, q, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - token_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - layer_weight.tp_alibi, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.total_token_num, - alloc_tensor_func=self.alloc_tensor, - ) - return o_tensor - def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_)) return o_tensor diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index afc8c93081..83f7674531 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def 
__init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.pre_norm_weight_ = NoTpNormWeight( weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 7b27ce6f2c..599893655d 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -48,8 +48,8 @@ def get_slopes_power_of_2(n): class BloomTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode, quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/bloom/model.py b/lightllm/models/bloom/model.py index 7e44ec2ebf..925620bf96 100644 --- a/lightllm/models/bloom/model.py +++ b/lightllm/models/bloom/model.py @@ -5,6 +5,7 @@ from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend @ModelRegistry("bloom") @@ -35,3 +36,8 @@ def _init_config(self): def _reset_num_key_value_heads(self): self.config["num_key_value_heads"] = self.config["num_attention_heads"] return + + def _init_att_backend(self): + self.prefill_att_backend = TritonAttBackend(self) + self.decode_att_backend = TritonAttBackend(self) + return diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 07ffc4beab..0000000000 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight - - -class ChatGLM2TransformerLayerInfer(LlamaTransformerLayerInfer): - """ """ - - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) - return - - def swiglu(self, x): - x = torch.chunk(x, 2, dim=-1) - return torch.nn.functional.silu(x[0]) * x[1] - - def _ffn( - self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - input = None - ffn1_out = self.swiglu(up_gate_out) - up_gate_out = None - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index 0139eb8837..0000000000 --- a/lightllm/models/chatglm2/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,20 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights 
import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class ChatGLM2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - - self.wte_weight_ = EmbeddingWeight( - weight_name="transformer.embedding.word_embeddings.weight", data_type=self.data_type_ - ) - self.lm_head_weight_ = LMHeadWeight( - weight_name="transformer.output_layer.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - weight_name="transformer.encoder.final_layernorm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py deleted file mode 100755 index d4dd1b7a29..0000000000 --- a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,72 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight - - -class ChatGLM2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__( - layer_num, - data_type, - network_config, - mode, - quant_cfg, - ) - return - - def _preprocess_weight(self, weights): - n_kv_embed = self.head_dim * self.n_kv_head - qkv_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" - if qkv_weight_name in weights: - qkv_weight_ = weights[qkv_weight_name] - weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :] - weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :] - weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :] - del weights[qkv_weight_name] - - qkv_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" - if qkv_bias_name in weights: - qkv_bias_ = weights[qkv_bias_name] - weights[self._q_bias_name] = qkv_bias_[: self.n_embed] - weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed] - weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed] - del weights[qkv_bias_name] - - gate_up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_h_to_4h.weight" - if gate_up_weight_name in weights: - gate_up_weight_ = weights[gate_up_weight_name] - weights[self._gate_weight_name] = gate_up_weight_[: self.n_inter, :] - weights[self._up_weight_name] = gate_up_weight_[self.n_inter : 2 * self.n_inter, :] - del weights[gate_up_weight_name] - - def _parse_config(self): - self.n_embed = self.network_config_["hidden_size"] - self.n_head = self.network_config_["num_attention_heads"] - self.n_inter = self.network_config_["ffn_hidden_size"] - self.n_kv_head = self.network_config_["multi_query_group_num"] - self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head) - - def load_hf_weights(self, weights): - self._preprocess_weight(weights) - super().load_hf_weights(weights) - return - - def _init_weight_names(self): - self._q_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.weight" - self._q_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.q_proj.bias" - self._k_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.weight" - self._k_bias_name = 
f"transformer.encoder.layers.{self.layer_num_}.self_attention.k_proj.bias" - self._v_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.weight" - self._v_bias_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.v_proj.bias" - self._o_weight_name = f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" - self._o_bias_name = None - - self._gate_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.gate_proj.weight" - self._gate_bias_name = None - self._up_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.up_proj.weight" - self._up_bias_name = None - self._down_weight_name = f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" - self._down_bias_name = None - - self._att_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.input_layernorm.weight" - self._att_norm_bias_name = None - self._ffn_norm_weight_name = f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" - self._ffn_norm_bias_name = None diff --git a/lightllm/models/chatglm2/model.py b/lightllm/models/chatglm2/model.py deleted file mode 100644 index e6aa395275..0000000000 --- a/lightllm/models/chatglm2/model.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -import json -import torch - -from lightllm.models.registry import ModelRegistry -from lightllm.models.chatglm2.layer_infer.transformer_layer_infer import ChatGLM2TransformerLayerInfer -from lightllm.models.chatglm2.layer_weights.transformer_layer_weight import ChatGLM2TransformerLayerWeight -from lightllm.models.chatglm2.layer_weights.pre_and_post_layer_weight import ChatGLM2PreAndPostLayerWeight -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.common.build_utils import repair_config -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -@ModelRegistry("chatglm") -class ChatGlm2TpPartModel(LlamaTpPartModel): - # Please use the fast tokenizer from: - # [THUDM/chatglm3-6b PR #12](https://huggingface.co/THUDM/chatglm3-6b/discussions/12). 
- - # weight class - pre_and_post_weight_class = ChatGLM2PreAndPostLayerWeight - transformer_weight_class = ChatGLM2TransformerLayerWeight - - # infer class - transformer_layer_infer_class = ChatGLM2TransformerLayerInfer - - def __init__(self, kvargs): - super().__init__(kvargs) - - def _init_config(self): - super()._init_config() - # rename key - # repair_config() - repair_config(self.config, same_names=["num_hidden_layers", "n_layer", "num_layers"]) - repair_config(self.config, same_names=["vocab_size", "padded_vocab_size"]) - repair_config(self.config, same_names=["rms_norm_eps", "layernorm_epsilon"]) - repair_config(self.config, same_names=["seq_length", "max_sequence_length"]) - return - - def _reset_num_key_value_heads(self): - self.config["num_key_value_heads"] = self.config["multi_query_group_num"] - return - - def _verify_params(self): - assert self.load_way == "HF", "ChatGLM only support HF format for now" - assert self.tp_world_size_ in [1, 2], "ChatGLM can only run in tp=1 or tp=2 for now" - - def _init_to_get_rotary(self, base=10000): - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_seq_len = self.config.get("max_position_embeddings", 2048) * rope_scaling_factor - - base = float(base) * self.config.get("rope_ratio", 1.0) - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula - except: - pass - n_elem = self.head_dim_ // 2 - inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) - t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py b/lightllm/models/chatglm2/triton_kernel/rotary_emb.py deleted file mode 100755 index ad1d1c2cf0..0000000000 --- a/lightllm/models/chatglm2/triton_kernel/rotary_emb.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = 
cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/cohere/infer_struct.py b/lightllm/models/cohere/infer_struct.py deleted file mode 100644 index d9571af92b..0000000000 --- a/lightllm/models/cohere/infer_struct.py +++ /dev/null @@ -1,8 +0,0 @@ -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - - -class CohereInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self._attn_out = None - self._ffn_out = None diff --git 
a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py deleted file mode 100644 index 67987a8d3b..0000000000 --- a/lightllm/models/cohere/layer_infer/post_layer_infer.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import numpy as np -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward -from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer -from lightllm.common.build_utils import repair_config -from lightllm.distributed.communication_op import all_gather - - -class CoherePostLayerInfer(LlamaPostLayerInfer): - def __init__(self, network_config, mode): - repair_config(config=network_config, same_names=["layer_norm_eps", "rms_norm_eps"]) - super().__init__(network_config, mode) - self.eps_ = network_config["layer_norm_eps"] - self.logits_scale = network_config["logit_scale"] - return - - def _norm( - self, input: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ) -> torch.Tensor: - return layernorm_forward( - input.unsqueeze(1), layer_weight.final_norm_weight_.weight.unsqueeze(0), eps=self.eps_ - ).squeeze(1) - - def token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) - input_embdings_dtype = input_embdings.dtype - input_embdings = None - last_input = self._norm(last_input, infer_state, layer_weight) - last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) - last_input = None - vocab_size = layer_weight.lm_head_weight_.vocab_size - if self.tp_world_size_ == 1: - gather_data = logic_batch - else: - gather_data = self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) - split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - all_gather( - [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], - logic_batch, - group=infer_state.dist_group, - async_op=False, - ) - gather_data = gather_data * self.logits_scale - logic_batch = None - ans_logics = self.alloc_tensor( - (token_num, vocab_size), - dtype=torch.float32, - ) - ans_logics[:, :] = gather_data.permute(1, 0) - gather_data = None - return ans_logics - - def tpsp_token_forward( - self, input_embdings: torch.Tensor, infer_state: CohereInferStateInfo, layer_weight: CoherePreAndPostLayerWeight - ): - raise NotImplementedError("not impl") - - def overlap_tpsp_token_forward( - self, - input_embdings: torch.Tensor, - input_embdings1: torch.Tensor, - infer_state: CohereInferStateInfo, - infer_state1: CohereInferStateInfo, - layer_weight: CoherePreAndPostLayerWeight, - ): - raise NotImplementedError("not impl") diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py deleted file mode 100644 index 0cdd281a37..0000000000 --- a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from functools import partial - -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, 
-) -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm -from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd - - -class CohereTransformerLayerInfer(TransformerLayerCohereInferTpl): - def __init__(self, layer_num, network_config, mode): - super().__init__(layer_num, network_config, mode) - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - self.eps_ = self.network_config_["layer_norm_eps"] - self.use_qk_norm_ = network_config.get("use_qk_norm", False) - self._bind_func() - - def _bind_func(self): - self._bind_rotary_emb_fwd() - self._bind_norm() - self._bind_attn() - - def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): - return rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv, - position_cos, - position_sin, - ) - - def _bind_rotary_emb_fwd(self): - self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self) - - def _att_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward( - input.unsqueeze(1), layer_weight.att_norm_weight_.weight.unsqueeze(0), self.eps_ - ).squeeze(1) - - def _q_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, layer_weight.q_norm_weight_.weight, self.eps_) - - def _k_norm(self, input, infer_state, layer_weight: CohereTransformerLayerWeight): - return layernorm_forward(input, layer_weight.k_norm_weight_.weight, self.eps_) - - def _bind_norm(self): - self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) - self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) - self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) - - def _bind_attn(self): - # no need to re-impl - LlamaTransformerLayerInfer._bind_attention(self) - - def _get_o( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - # o_tensor = layer_weight.mm_op.apply(input, layer_weight.o_weight_) - o_tensor = layer_weight.o_proj.mm(input) - return o_tensor - - def _ffn( - self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - up_gate_out = layer_weight.gate_up_proj.mm(input) - ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) - silu_and_mul_fwd(up_gate_out, ffn1_out) - input = None - up_gate_out = None - # ffn2_out = layer_weight.mm_op.apply(ffn1_out, layer_weight.down_proj) - ffn2_out = layer_weight.down_proj.mm(ffn1_out) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py 
b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py deleted file mode 100644 index f2e5f85472..0000000000 --- a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight - - -class CoherePreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) - tie_weight = self.network_config_.get("tie_word_embeddings", True) - - self.wte_weight_ = EmbeddingWeight( - weight_name="model.embed_tokens.weight", - data_type=self.data_type_, - ) - if tie_weight: - self.lm_head_weight_ = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="model.lm_head.weight", - data_type=self.data_type_, - ) - self.final_norm_weight_ = NoTpNormWeight( - weight_name="model.norm.weight", - data_type=self.data_type_, - bias_name=None, - ) diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py deleted file mode 100644 index 9c446b49e9..0000000000 --- a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight, TpHeadNormWeight - - -class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) - return - - def _parse_config(self): - super()._parse_config() - self.use_qk_norm = self.network_config_.get("use_qk_norm", False) - - def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight(self._att_norm_weight_name, self.data_type_) - - if self.use_qk_norm: - self.q_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_norm.weight", self.data_type_ - ) - self.k_norm_weight_ = TpHeadNormWeight( - f"model.layers.{self.layer_num_}.self_attn.k_norm.weight", self.data_type_ - ) - - return diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py deleted file mode 100644 index 5b317c1331..0000000000 --- a/lightllm/models/cohere/model.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import torch -from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( - TransformerLayerCohereInferTpl, -) -from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.models.registry import ModelRegistry -from lightllm.models.cohere.infer_struct import CohereInferStateInfo -from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer -from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer -from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight -from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight -from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - 
- -@ModelRegistry("cohere") -class CohereTpPartModel(LlamaTpPartModel): - pre_and_post_weight_class = CoherePreAndPostLayerWeight - transformer_weight_class = CohereTransformerLayerWeight - - pre_layer_infer_class = LlamaPreLayerInfer - transformer_layer_infer_class = CohereTransformerLayerInfer - post_layer_infer_class = CoherePostLayerInfer - - infer_state_class = CohereInferStateInfo - - def _init_to_get_rotary(self, default_base=10000): - partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - - base = self.config.get("rope_theta", float(default_base)) - - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_position_embeddings = self.config.get( - "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384 - ) - max_seq_len = max_position_embeddings * rope_scaling_factor - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula - except: - pass - - inv_freq = 1.0 / ( - base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) - ) - t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor - freqs = torch.outer(t, inv_freq) - freqs = torch.repeat_interleave(freqs, 2, dim=-1) - - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py deleted file mode 100644 index c1d5ff4cd6..0000000000 --- a/lightllm/models/cohere/triton_kernels/layernorm.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import triton -import triton.language as tl - -# LayerNorm adapted from triton tutorial, used for Cohere q, k norm -# X [N, head_num, head_dim] -# W [head_num, head_dim] -@triton.jit -def _layer_norm_fwd_kernel( - X, # pointer to the input - W, # pointer to the weights - Y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, -): - Seq = tl.program_id(0) - H = tl.program_id(1) - - X += Seq * stride_x_N + H * stride_x_hn - Y += Seq * stride_y_N + H * stride_y_hn - W += H * stride_w_hn - - _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=0) / N - - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.0) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = (x - mean) * rstd 
- y = x_hat * w - - tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask) - - -def layernorm_forward( - X, # pointer to the input - W, # pointer to the weights - eps, # epsilon to avoid division by zero -): - assert len(X.shape) == 3 - assert len(W.shape) == 2 - assert X.shape[-1] == W.shape[-1] - assert X.shape[-2] == W.shape[-2] - - y = torch.empty_like(X) - - stride_x_N = X.stride(0) - stride_x_hn = X.stride(1) - stride_x_hd = X.stride(2) - - stride_y_N = y.stride(0) - stride_y_hn = y.stride(1) - stride_y_hd = y.stride(2) - - stride_w_hn = W.stride(0) - stride_w_hd = W.stride(1) - - N = X.shape[-1] - BLOCK_SIZE = 128 - - grid = (X.shape[0], X.shape[1]) - _layer_norm_fwd_kernel[grid]( - X, - W, - y, - stride_x_N, - stride_x_hn, - stride_x_hd, - stride_y_N, - stride_y_hn, - stride_y_hd, - stride_w_hn, - stride_w_hd, - N, - eps, - BLOCK_SIZE, - ) - - return y - - -def torch_layernorm(x, weight, eps): - inp_dtype = x.dtype - x = x.to(torch.float32) - mean = x.mean(-1, keepdim=True) - variance = (x - mean).pow(2).mean(-1, keepdim=True) - x = (x - mean) * torch.rsqrt(variance + eps) - x = weight.to(torch.float32) * x - return x.to(inp_dtype) - - -def test_layernorm(eps=1e-5): - # create data - dtype = torch.float16 - x_shape = (5, 1, 128) - w_shape = (x_shape[-2], x_shape[-1]) - weight = torch.rand(w_shape, dtype=dtype, device="cuda") - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - # forward pass - y_ref = torch_layernorm(x, weight, eps).to(dtype) - y_out = layernorm_forward(x, weight, eps) - - # compare - print("type:", y_out.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_out - y_ref))) - assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0) - return diff --git a/lightllm/models/cohere/triton_kernels/rotary_emb.py b/lightllm/models/cohere/triton_kernels/rotary_emb.py deleted file mode 100644 index ac338e71ef..0000000000 --- a/lightllm/models/cohere/triton_kernels/rotary_emb.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - Q, - K, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + 1 - - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), - other=0.0, - ) - - cos0 = tl.load(Cos + 
off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos0 - q1 * sin0 - out1 = q0 * sin1 + q1 * cos1 - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - - cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos0 - k1 * sin0 - out_k1 = k0 * sin1 + k1 * cos1 - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - return - - -def torch_cohere_rotary_emb(x, cos, sin): - dtype = x.dtype - seq_len, h, dim = x.shape - x = x.float() - x1 = x[:, :, ::2] - x2 = x[:, :, 1::2] - rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) - cos = cos.view((seq_len, 1, dim)) - sin = sin.view((seq_len, 1, dim)) - o = (x * cos) + (rot_x * sin) - return o.to(dtype=dtype) - - -@torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): - total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = int(q.shape[2] * partial_rotary_factor) - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" - - BLOCK_SEQ = 16 - BLOCK_HEAD = 4 - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( - q, - k, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num_q, - head_num_k, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - 
BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"): - # create data - x_shape = (SEQ_LEN, H, D) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - y = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - cos_shape = (SEQ_LEN, D) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - # forward pass - y_tri = torch_cohere_rotary_emb(x, cos, sin) - rotary_emb_fwd(x, y, cos, sin) - y_ref = x - - # compare - print("type:", y_tri.dtype, y_ref.dtype) - print("max delta:", torch.max(torch.abs(y_tri - y_ref))) - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py deleted file mode 100644 index 72ba8a43b1..0000000000 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - args_mtp_step = get_env_start_args().mtp_step - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - self.has_prefix_kv = self.max_cache_len > 0 - if self.has_prefix_kv: - self.cu_seqlens_prefix_k = torch.nn.functional.pad( - torch.cumsum(self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0) - ) - self.prefix_k_max_len = self.max_cache_len - self.prefix_total_token_num = self.prefix_total_token_num - else: - # Meta information of flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len - self.cu_seqlens_k = self.b1_cu_kv_seq_len - max_seq_len_k = self.max_kv_seq_len - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].view(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty((att_batch_size, self.max_len_in_batch), dtype=torch.int32).to( - self.input_ids.device - ) - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if args_mtp_step > 0: - self.b_att_seq_len = 
self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py deleted file mode 100644 index db6386f797..0000000000 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class Deepseek2FlashInferStateInfo(Deepseek2InferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(self.input_ids.device) - if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - self.flashinfer_extra_state.workspace_buffer, - use_cuda_graph=True, - qo_indptr=self.q_indptr, - kv_indices=self.kv_indices, - kv_indptr=self.kv_starts, - kv_len_arr=self.b_seq_len, - ) - self.decode_wrapper.plan( - self.q_indptr, - self.kv_starts, - self.kv_indices, - self.b_seq_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.kv_lora_rank, - self.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - self.flashinfer_extra_state.softmax_scale, - self.flashinfer_extra_state.q_data_type, - self.flashinfer_extra_state.kv_data_type, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_kv_start_loc.int() - if self.prefill_wrapper is None: - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, "NHD" - ) - self.prefill_wrapper.plan( - qo_indptr=q_starts, - kv_indptr=kv_starts, - num_qo_heads=self.flashinfer_extra_state.tp_q_head_num, - num_kv_heads=self.flashinfer_extra_state.tp_q_head_num, - head_dim_qk=self.flashinfer_extra_state.qk_nope_head_dim - + self.flashinfer_extra_state.qk_rope_head_dim, - head_dim_vo=self.flashinfer_extra_state.qk_nope_head_dim, - q_data_type=self.flashinfer_extra_state.q_data_type, - causal=True, - sm_scale=self.flashinfer_extra_state.softmax_scale, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.q_indptr, - new_infer_state.kv_starts, - new_infer_state.kv_indices, - 
new_infer_state.b_seq_len, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.kv_lora_rank, - new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, - False, # causal - new_infer_state.flashinfer_extra_state.softmax_scale, - new_infer_state.flashinfer_extra_state.q_data_type, - new_infer_state.flashinfer_extra_state.kv_data_type, - ) - return diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index 0c2ef30489..4dd79305c3 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -1,21 +1,6 @@ -import os -import torch -import numpy as np -import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo class Deepseek2InferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() - self.kv_starts = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - if not self.is_prefill: - self.kv_starts = self.b1_cu_kv_seq_len - - if self.is_prefill: - self.b1_kv_start_loc = self.b1_cu_kv_seq_len - self.max_value_in_b_seq_len = self.max_kv_seq_len - return diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ff20bc6ee6..8695f2de89 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -1,40 +1,25 @@ import os import torch -import torch.functional as F import torch.distributed as dist -import numpy as np import triton -from typing import Tuple from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8 -from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv -from lightllm.models.deepseek2.triton_kernel.repeat_rope import repeat_rope -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size from 
lightllm.utils.log_utils import init_logger -from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 logger = init_logger(__name__) class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self.qk_nope_head_dim = network_config["qk_nope_head_dim"] @@ -66,7 +51,7 @@ def __init__(self, layer_num, network_config, mode=[]): mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] return @@ -89,58 +74,81 @@ def _bind_ffn(self): self._ffn = partial(LlamaTransformerLayerInfer._ffn, self) self._tpsp_ffn = self._tpsp_ffn_tp - def _bind_attention(self): - if "triton_fp8kv" in self.mode: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self) - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self - ) - else: - self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) - elif get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self - ) - if self.enable_cc_method: - if "triton_fp8kv" in self.mode: - if get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self - ) - else: - if get_env_start_args().enable_fa3: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self - ) - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self - ) - else: - if "triton_fp8kv" in self.mode: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self - ) - else: - self._context_attention_kernel = partial( - Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self - ) + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + k_nope, k_rope, v = self._decompress_kv( + infer_state=infer_state, + layer_weight=layer_weight, + ) + + o_tensor = infer_state.prefill_att_state.prefill_att( + q=q, + k=(k_nope, k_rope), + v=v, + 
att_control=AttControl(mla_prefill=True, mla_prefill_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return o_tensor + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + + out = infer_state.decode_att_state.decode_att( + q=(q_nope, q_rope), + k=kv, + v=None, + att_control=AttControl(mla_decode=True, mla_decode_dict={"softmax_scale": self.softmax_scale}), + alloc_func=self.alloc_tensor, + ) + return out + + def _decompress_kv( + self, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + ): + compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + + total_token_num = infer_state.total_token_num + sampled_compressed_kv_nope = self.alloc_tensor( + [total_token_num, 1, layer_weight.kv_lora_rank], dtype=compressed_kv.dtype + ) + sampled_k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=compressed_kv.dtype) + sample_kv( + all_compressed_kv=compressed_kv, + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + sampled_k_rope=sampled_k_rope, + b_req_idx=infer_state.b_req_idx, + req_to_token_indexs=infer_state.req_manager.req_to_token_indexs, + b_seq_len=infer_state.b_seq_len, + b_kv_start_loc=infer_state.b1_cu_kv_seq_len[:-1], + max_kv_seq_len=infer_state.max_kv_seq_len, + ) + # CC + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view( + total_token_num, layer_weight.kv_lora_rank + ).contiguous() + sampled_kv_nope = self.alloc_tensor( + [total_token_num, self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], + dtype=sampled_compressed_kv_nope.dtype, + ) + layer_weight.cc_kv_b_proj_.mm(sampled_compressed_kv_nope, out=sampled_kv_nope.view(total_token_num, -1)) + sampled_k_nope, sampled_v = torch.split(sampled_kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + return sampled_k_nope, sampled_k_rope, sampled_v def _get_qkv( self, @@ -297,423 +305,6 @@ def _tpsp_get_o( return o_tensor - def _decompress_kv( - self, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - is_fp8, - total_token_num, - b_seq_len, - max_seq_len, - b_kv_start_loc, - skip_sample=False, - ): - if not skip_sample: - if is_fp8: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype) - else: - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = None - k_scale = None - - compressed_kv = self.alloc_tensor([total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype) - k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype) - sample_kv( - kv, - compressed_kv, - k_rope, - infer_state.b_req_idx, - max_seq_len, - b_seq_len, - infer_state.req_manager.req_to_token_indexs, - b_kv_start_loc, - kv_scale, - k_scale, - ) - if k_scale is not None: - compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1) - k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1) - else: - compressed_kv, k_rope = torch.split( # (b*s, 1, kv_lora + qk_r) 
- kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1 - ) - - # CC - compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous() - kv_nope = self.alloc_tensor( - [compressed_kv.shape[0], self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)], - dtype=compressed_kv.dtype, - ) - layer_weight.cc_kv_b_proj_.mm(compressed_kv, out=kv_nope.reshape(compressed_kv.shape[0], -1)) - k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - return k_nope, k_rope, v - - # Adapted from: - # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962 - def _context_attention_flashattention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashAttentionStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - skip_sample=True, - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - o_tensor, lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_q, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.max_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - return_softmax_lse=True, - ) - if infer_state.has_prefix_kv: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.prefix_total_token_num, - infer_state.b_ready_cache_len, - infer_state.prefix_k_max_len, - infer_state.cu_seqlens_prefix_k, - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - prefix_output, prefix_lse, *rest = flash_attn_varlen_func( - q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim), - v=v.view(-1, self.tp_v_head_num_, self.v_head_dim), - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k=infer_state.cu_seqlens_prefix_k, - max_seqlen_q=infer_state.q_max_seq_len, - max_seqlen_k=infer_state.prefix_k_max_len, - softmax_scale=self.softmax_scale, - causal=False, - return_softmax_lse=True, - ) - lse = torch.transpose(lse, 0, 1).contiguous() - prefix_lse = torch.transpose(prefix_lse, 0, 1).contiguous() - tmp_output = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) - if out is None - else out - ) - tmp_lse = torch.empty_like(lse) - merge_state_v2(prefix_output, prefix_lse, o_tensor, lse, tmp_output, tmp_lse) - o_tensor = tmp_output - return o_tensor - - def _context_attention_flashinfer_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], 
self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - - def _context_attention_flashinfer_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2FlashInferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - o_tensor = ( - self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out - ) - k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) - infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) - return o_tensor - - def _context_attention_kernel_with_CC( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - False, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_with_CC_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - k_nope, k_rope, v = self._decompress_kv( - kv, - infer_state, - layer_weight, - True, - infer_state.total_token_num, - infer_state.b_seq_len, - infer_state.max_value_in_b_seq_len, - infer_state.b1_kv_start_loc, - ) - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - context_attention_fwd_with_v( - q_nope, - q_rope, - k_nope, - k_rope, - v, - o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), - infer_state.b_start_loc, - infer_state.b1_kv_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - 
infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - - def _context_attention_kernel_origin_fp8( - self, - q: torch.Tensor, - kv, - infer_state: Deepseek2InferStateInfo, - layer_weight: Deepseek2TransformerLayerWeight, - out=None, - ) -> torch.Tensor: - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - context_attention_fwd_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - self.softmax_scale, - ) - return o_tensor - - def _token_gqa_decode_attention_flashattention( - self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) - k_descale, v_descale = None, None - o_tensor = flash_attn_with_kvcache( - q=q_rope, - k_cache=k_rope, - v_cache=kv_nope, - qv=q_nope, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_att_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=self.softmax_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashinfer( - self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) - - infer_state.decode_wrapper.run( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - out=o_tensor, - return_lse=False, - ) - return o_tensor - - def _token_gqa_decode_attention_flashdecoding( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - out = gqa_token_decode_attention_flash_decoding( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim 
:], - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, - ) - return out - - def _token_gqa_decode_attention_flashdecoding_fp8( - self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None - ): - q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] - q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) - - kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn) - kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16) - return gqa_token_decode_attention_flash_decoding_fp8( - q_nope, - q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], - kv_scale, - infer_state, - self.tp_q_head_num_, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.qk_nope_head_dim, - self.softmax_scale, - alloc_tensor_func=self.alloc_tensor, - ) - - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank], - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank :], - ) - return - - def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager): - destindex_copy_kv_fp8( - buffer[:, :, : self.kv_lora_rank], - buffer[:, :, self.kv_lora_rank :], - mem_index, - mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn), - mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype), - ) - return - def _moe_ffn( self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight ) -> torch.Tensor: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 611878f9e8..c5a2d33527 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -17,9 +17,9 @@ class Deepseek2TransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e4ce7c8269..f0739a8a81 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -4,51 +4,16 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo -from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo -from 
lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights - from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id - +from lightllm.common.basemodel.attention import get_mla_decode_att_backend_class, get_mla_prefill_att_backend_class logger = init_logger(__name__) -class DeepSeek2FlashInferStateExtraInfo: - def __init__(self, model): - num_heads = model.config["num_attention_heads"] - self.tp_q_head_num = num_heads // get_dp_world_size() - self.qk_nope_head_dim = model.qk_nope_head_dim - self.qk_rope_head_dim = model.qk_rope_head_dim - self.kv_lora_rank = model.kv_lora_rank - self.q_data_type = model.data_type - self.kv_data_type = model.data_type - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - if model.config["rope_scaling"] is not None: - rope_scaling = model.config["rope_scaling"] - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) - scaling_factor = rope_scaling["factor"] - if mscale_all_dim: - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - @ModelRegistry(["deepseek_v2", "deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class @@ -61,18 +26,13 @@ class Deepseek2TpPartModel(LlamaTpPartModel): infer_state_class = Deepseek2InferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = Deepseek2FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = Deepseek2FlashInferStateInfo - self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self) + def _init_att_backend(self): + self.prefill_att_backend = get_mla_prefill_att_backend_class(index=0)(model=self) + self.decode_att_backend = get_mla_decode_att_backend_class(index=0)(model=self) + return def _init_some_value(self): super()._init_some_value() diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py index fd437c3888..b9be73e278 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py @@ -5,7 +5,6 @@ import triton.language as tl from typing import List from lightllm.utils.log_utils import init_logger -from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig from lightllm.utils.device_utils import get_device_sm_count logger = init_logger(__name__) @@ 
-28,16 +27,18 @@ def gqa_token_decode_attention_flash_decoding_fp8( **run_config ): batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch + max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, kv_lora_rank) calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: if torch.cuda.is_current_stream_capturing(): - avg_seq_len_in_batch = max_len_in_batch + avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size + from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig + run_config = MlaDecodeAttentionKernelConfig.try_to_get_best_config( batch_size=batch_size, avg_seq_len_in_batch=avg_seq_len_in_batch, @@ -191,7 +192,7 @@ def _fwd_kernel_calcu_index_and_block_seq( infer_state = Deepseek2InferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_kv_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX infer_state.req_manager = ReqManager(Z, N_CTX, None) infer_state.req_manager.req_to_token_indexs = req_to_token_indexs diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index af0aaa2f66..53a0a60eb2 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -8,111 +8,101 @@ @triton.jit def _sample_kv_kernel( - KV_input, - KV_scale, - KV_nope, - KV_rope, - K_scale, - B_start_loc, - B_Seqlen, - Req_to_tokens, - B_req_idx, - stride_input_dim, - stride_scale_dim, - stride_nope_dim, - stride_rope_dim, + all_compressed_kv, + stride_all_s, + stride_all_d, + sampled_compressed_kv_nope, + stride_nope_s, + stride_nope_d, + sampled_k_rope, + stride_rope_s, + stride_rope_d, + b_kv_start_loc, + b_seq_len, + req_to_token_indexs, stride_req_to_tokens_b, - HAS_SCALE: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, + b_req_idx, + BLOCK_SEQ: tl.constexpr, + BLOCK_NOPE_DIM: tl.constexpr, + BLOCK_ROPE_DIM: tl.constexpr, ): cur_batch = tl.program_id(0) start_m = tl.program_id(1) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_loc = tl.load(B_start_loc + cur_batch) + cur_batch_seq_len = tl.load(b_seq_len + cur_batch) + cur_batch_req_idx = tl.load(b_req_idx + cur_batch) + cur_batch_start_loc = tl.load(b_kv_start_loc + cur_batch) - offs_nope_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_nope_d = tl.arange(0, BLOCK_NOPE_DIM) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DIM) + offs_m = (start_m * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)) % cur_batch_seq_len - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + if start_m * BLOCK_SEQ > cur_batch_seq_len: + return kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, - mask=offs_m < block_end_loc, - other=0, + req_to_token_indexs + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, ).to(tl.int64) - off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] - off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] - kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) - kv_rope = tl.load(KV_input + off_kv_rope, mask=offs_m[:, None] < block_end_loc, other=0.0) - off_nope = (offs_m + cur_batch_start_loc)[:, 
None] * stride_nope_dim + offs_nope_d[None, :] - off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_dim + offs_rope_d[None, :] - nope_ptrs = KV_nope + off_nope - rope_ptrs = KV_rope + off_rope - tl.store(nope_ptrs, kv_nope, mask=offs_m[:, None] < block_end_loc) - tl.store(rope_ptrs, kv_rope, mask=offs_m[:, None] < block_end_loc) - if HAS_SCALE: - kv_scale = tl.load(KV_scale + kv_loc * stride_scale_dim, mask=offs_m < block_end_loc) - off_k_scale = cur_batch_start_loc + offs_m - k_scale_ptrs = K_scale + off_k_scale - tl.store(k_scale_ptrs, kv_scale, mask=offs_m < block_end_loc) + off_kv_nope = kv_loc[:, None] * stride_all_s + offs_nope_d[None, :] + off_kv_rope = kv_loc[:, None] * stride_all_s + (offs_rope_d + BLOCK_NOPE_DIM)[None, :] + kv_nope = tl.load(all_compressed_kv + off_kv_nope) + kv_rope = tl.load(all_compressed_kv + off_kv_rope) + off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_s + offs_nope_d[None, :] + off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_s + offs_rope_d[None, :] + nope_ptrs = sampled_compressed_kv_nope + off_nope + rope_ptrs = sampled_k_rope + off_rope + tl.store(nope_ptrs, kv_nope) + tl.store(rope_ptrs, kv_rope) return @torch.no_grad() def sample_kv( - kv_input, - kv_nope, - kv_rope, - b_req_idx, - max_value_in_b_seq_len, - b_seq_len, - req_to_token_indexs, - b_kv_start_loc, - kv_scale=None, - k_scale=None, + all_compressed_kv: torch.Tensor, + sampled_compressed_kv_nope: torch.Tensor, + sampled_k_rope: torch.Tensor, + b_req_idx: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_seq_len: torch.Tensor, + b_kv_start_loc: torch.Tensor, + max_kv_seq_len: int, ): - BLOCK = 128 if not is_tesla() else 64 - - nope_dim = kv_nope.shape[-1] - rope_dim = kv_rope.shape[-1] - if nope_dim >= 512: - BLOCK = 64 if not is_tesla() else 32 - else: - BLOCK = 128 if not is_tesla() else 64 - + nope_dim = sampled_compressed_kv_nope.shape[-1] + rope_dim = sampled_k_rope.shape[-1] + assert rope_dim == 64 batch = b_seq_len.shape[0] - max_input_len = max_value_in_b_seq_len + BLOCK = 64 if not is_tesla() else 32 + num_warps = 8 grid = ( batch, - triton.cdiv(max_input_len, BLOCK), + triton.cdiv(max_kv_seq_len, BLOCK), ) - num_warps = 4 if nope_dim <= 64 else 8 + + all_compressed_kv = all_compressed_kv.view(all_compressed_kv.shape[0], all_compressed_kv.shape[2]) + sampled_compressed_kv_nope = sampled_compressed_kv_nope.view(sampled_compressed_kv_nope.shape[0], nope_dim) + sampled_k_rope = sampled_k_rope.view(sampled_k_rope.shape[0], rope_dim) + assert triton.next_power_of_2(nope_dim) == nope_dim + assert triton.next_power_of_2(rope_dim) == rope_dim _sample_kv_kernel[grid]( - kv_input, - kv_scale, - kv_nope, - kv_rope, - k_scale, - b_kv_start_loc, - b_seq_len, - req_to_token_indexs, - b_req_idx, - kv_input.stride(0), - kv_scale.stride(0) if kv_scale is not None else 0, - kv_nope.stride(0), - kv_rope.stride(0), - req_to_token_indexs.stride(0), - HAS_SCALE=kv_scale is not None, - BLOCK_M=BLOCK, - BLOCK_DMODEL=nope_dim, - BLOCK_ROPE_DMODEL=rope_dim, + all_compressed_kv=all_compressed_kv, + stride_all_s=all_compressed_kv.stride(0), + stride_all_d=all_compressed_kv.stride(1), + sampled_compressed_kv_nope=sampled_compressed_kv_nope, + stride_nope_s=sampled_compressed_kv_nope.stride(0), + stride_nope_d=sampled_compressed_kv_nope.stride(1), + sampled_k_rope=sampled_k_rope, + stride_rope_s=sampled_k_rope.stride(0), + stride_rope_d=sampled_k_rope.stride(1), + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + 
stride_req_to_tokens_b=req_to_token_indexs.stride(0), + b_req_idx=b_req_idx, + BLOCK_SEQ=BLOCK, + BLOCK_NOPE_DIM=nope_dim, + BLOCK_ROPE_DIM=rope_dim, num_warps=num_warps, num_stages=1, ) diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index 26bfc865e4..adb749c40e 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] self.hidden_size = network_config["hidden_size"] return diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 4a5bf2e961..1f0815c3db 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.eh_proj.weight", diff --git a/lightllm/models/gemma3/layer_infer/post_layer_infer.py b/lightllm/models/gemma3/layer_infer/post_layer_infer.py index 22dc595059..57b21844ec 100644 --- a/lightllm/models/gemma3/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/post_layer_infer.py @@ -4,7 +4,7 @@ class Gemma3PostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = 1e-6 return diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index dc8a46ad91..3543786f69 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -5,8 +5,8 @@ class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32) self.boi_token_index: int = 255_999 self.eoi_token_index: int = 256_000 diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index d4bd8c3fa6..1f386625bf 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -1,12 +1,6 @@ import torch -import torch.functional as F import torch.distributed as dist import torch.nn as nn -import numpy as np -from typing import Tuple -from functools import partial -import triton - from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.distributed import all_reduce from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight @@ -18,8 +12,8 @@ class 
Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] self.tp_v_head_num_ = network_config["num_key_value_heads"] self.eps_ = 1e-6 diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 17e65268cc..858937d8c1 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="language_model.model.embed_tokens.weight", diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 1e7ceeb42a..e7808c412c 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -9,10 +9,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index dc4f03b7e1..9931c31713 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -1,7 +1,5 @@ import os -import re import json -import numpy as np import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer @@ -14,8 +12,6 @@ from lightllm.models.gemma3.layer_weights.pre_and_post_layer_weight import Gemma3PreAndPostLayerWeight from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel -from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer -from lightllm.models.llava.layer_weights.pre_and_post_layer_weight import LlavaPreAndPostLayerWeight from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem from lightllm.server.core.objs import SamplingParams from lightllm.common.build_utils import repair_config diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index ce9737820e..468d471d2c 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -11,8 +11,8 @@ class Gemma_2bPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) tp_vob_ids = np.linspace(0, network_config["vocab_size"], self.tp_world_size_ + 1, dtype=np.int64) self.vob_start_id_, self.vob_end_id_ = int(tp_vob_ids[self.tp_rank_]), int(tp_vob_ids[self.tp_rank_ + 1]) self.normfactor = network_config["hidden_size"] ** 
0.5 diff --git a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py index 35ddaef343..2ed325659d 100644 --- a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py @@ -16,8 +16,8 @@ class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1 self.tp_v_head_num_ = network_config["num_key_value_heads"] return diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index d5d0438fa3..6e052caa63 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 1916bd095c..9102ce6775 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Gemma_2bTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_qkv(self): diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 93cd7413ba..d80eefd16e 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -1,22 +1,15 @@ import torch -from torch import nn -from torch.nn import functional as F -import numpy as np -from functools import partial -from typing import Optional - from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, LlamaInferStateInfo +from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) class GptOssTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + 
super().__init__(layer_num, network_config) self.hidden_size = self.network_config_["hidden_size"] self.alpha = 1.702 self.limit = 7.0 @@ -24,22 +17,17 @@ def __init__(self, layer_num, network_config, mode=[]): self.sliding_window = network_config["sliding_window"] self.head_dim_ = network_config["head_dim"] - def _bind_attention(self): - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._context_attention_kernel = self._context_sliding_attention_flashattention - self._token_attention_kernel = self._token_sliding_attention_flashattention - def _bind_norm(self): self._att_norm = self._att_norm self._ffn_norm = self._ffn_norm return - def _att_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _att_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_) return out - def _ffn_norm(self, input, infer_state, layer_weight) -> torch.Tensor: + def _ffn_norm(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: out = self.alloc_tensor(input.shape, input.dtype) out = self._gpt_oss_rmsnorm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_) return out @@ -51,9 +39,7 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): hidden_states = hidden_states * torch.rsqrt(variance + eps) return (weight * hidden_states).to(input_dtype) # main diff with Llama - def _ffn( - self, input, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight - ) -> torch.Tensor: + def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -68,78 +54,61 @@ def _ffn( ) return hidden_states.view(num_tokens, hidden_dim) - def _context_sliding_attention_flashattention( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: LlamaInferStateInfo, + layer_weight: GptOssTransformerLayerWeight, + out=None, ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - 
k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, - ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor - def _token_sliding_attention_flashattention( - self, q, infer_state: FlashAttentionStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None + def _token_attention_kernel( + self, q: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: GptOssTransformerLayerWeight, out=None ): if self.network_config_["layer_types"][self.layer_num_] == "sliding_attention": window_size = (self.sliding_window - 1, self.sliding_window - 1) + use_sliding_window = True else: window_size = (-1, -1) + use_sliding_window = False - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - softmax_scale=sm_scale, - causal=True, - window_size=window_size, - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - sinks=layer_weight.attn_sinks.weight, + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, + k=_k, + v=_v, + att_control=AttControl( + use_sliding_window=use_sliding_window, + sliding_window=window_size, + use_att_sink=True, + sink_weight=layer_weight.attn_sinks.weight, + ), + alloc_func=self.alloc_tensor, ) - return o + o_tensor = o_tensor.view(q.shape) + return o_tensor diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index f6a841b1aa..c5c14b08e6 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -17,10 +17,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_moe(self): diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index 34a017b316..dc5f2abdfe 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -19,4 +19,9 @@ class GptOssTpPartModel(LlamaTpPartModel): def __init__(self, kvargs): super().__init__(kvargs) - assert get_env_start_args().enable_fa3, "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_prefill_att_backend[0] == 
"fa3" + ), "For now GPT-OSS type model only support flashattention-3" + assert ( + get_env_start_args().llm_decode_att_backend[0] == "fa3" + ), "For now GPT-OSS type model only support flashattention-3" diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index a2fc91dc46..6ef81122d3 100755 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -1,13 +1,9 @@ -import torch -import math -import numpy as np - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight class InternlmTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/internlm/model.py b/lightllm/models/internlm/model.py index 78ac7117e1..50adbb3f9f 100644 --- a/lightllm/models/internlm/model.py +++ b/lightllm/models/internlm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index b40330aa3d..3ed7004c12 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py index a675558632..a05e977f16 100755 --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Internlm2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index b20b9c4955..59caf40d6b 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class 
Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.tok_embeddings.weight", data_type=self.data_type_, diff --git a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py index 7d76d202ae..21a4c2e6b5 100644 --- a/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py @@ -13,8 +13,8 @@ def rename_weight_keys(weights): class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -24,8 +24,8 @@ def load_hf_weights(self, weights): class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -35,8 +35,8 @@ def load_hf_weights(self, weights): class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py deleted file mode 100644 index 9f71cbbc56..0000000000 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_current_device_id -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index -from lightllm.common.basemodel.batch_objs import ModelInput -from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy - - -class FlashAttentionStateInfo(LlamaInferStateInfo): - _shared_page_table_buffer = None - - def __init__(self): - super().__init__() - - @classmethod - def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): - if cls._shared_page_table_buffer is None: - cls._shared_page_table_buffer = [ - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()), - ] - return cls._shared_page_table_buffer - - def _init_flash_attention_state(self, model): - if self.is_prefill: - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=self.input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) - else: - # Meta information of 
flashattention for decoding - self.cu_seqlens_q = self.b1_cu_q_seq_len.int() - self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - max_seq_len_k = self.max_kv_seq_len - args_mtp_step = get_env_start_args().mtp_step - att_batch_size = self.batch_size // (args_mtp_step + 1) - if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch - ) - self.page_table = page_buffer[self.microbatch_index][ - : att_batch_size * model.graph_max_len_in_batch - ].reshape(att_batch_size, model.graph_max_len_in_batch) - else: - self.page_table = torch.empty( - (att_batch_size, self.max_len_in_batch), dtype=torch.int32, device=self.input_ids.device - ) - - page_table_copy( - page_table=self.page_table[:, :max_seq_len_k], - req_to_token_indexs=model.req_manager.req_to_token_indexs, - b_req_idx=self.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], - ) - if args_mtp_step > 0: - self.b_att_seq_len = self.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() - else: - self.b_att_seq_len = self.b_seq_len - - if "offline_calibration_fp8kv" in model.mode: - if self.is_prefill: - device = self.input_ids.device - # q_scale and token_batch_ids are used for per-head quantization of q; they are initialized outside inference to save resources - self.q_scale = torch.empty( - (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device - ) - self.token_batch_ids = torch.repeat_interleave( - torch.arange(self.batch_size, device=device), self.b_q_seq_len - ) - - offline_scales = self.mem_manager.scales - head_num = self.mem_manager.head_num - # To reduce inference-time computation, k_descale and v_descale are initialized outside inference - self.k_descale = ( - offline_scales[:, :head_num] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - self.v_descale = ( - offline_scales[:, head_num:] - .view(-1, 1, head_num) - .expand(offline_scales.shape[0], self.batch_size, head_num) - if offline_scales is not None - else torch.ones( - (self.mem_manager.layer_num, self.batch_size, head_num), - dtype=torch.float32, - device=self.input_ids.device, - ) - ) - return - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self._init_flash_attention_state(model) - return diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py deleted file mode 100644 index 7f9beac1db..0000000000 --- a/lightllm/models/llama/flashinfer_struct.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import torch -import numpy as np -import torch.distributed as dist -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index - - -class LlamaFlashInferStateInfo(LlamaInferStateInfo): - def __init__(self): - super().__init__() - self.prefill_wrapper = None - self.decode_wrapper = None - self.flashinfer_extra_state = None - - def init_some_extra_state(self, model): - super().init_some_extra_state(model) - self.flashinfer_extra_state = model.flashinfer_extra_state - - import flashinfer - - if not self.is_prefill: - if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device - ) - 
if self.batch_size <= model.graph_max_batch_size: - self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length - ] - else: - self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) - self.kv_starts = self.b1_cu_kv_seq_len.int() - if self.decode_wrapper is None: - self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=self.kv_starts, - paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, - ) - self.decode_wrapper.plan( - self.kv_starts, - self.kv_indices, - self.kv_last_page_len_buffer, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - else: - if get_env_start_args().enable_flashinfer_prefill: - q_starts = self.b1_cu_q_seq_len.int() - kv_starts = self.b1_cu_kv_seq_len.int() - kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=self.input_ids.device) - kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, - dtype=torch.int32, - device=self.input_ids.device, - ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts[:-1], - self.max_kv_seq_len, - kv_indices, - ) - self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_extra_state.workspace_buffer, - qo_indptr_buf=q_starts, - paged_kv_indptr_buf=kv_starts, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len, - ) - self.prefill_wrapper.plan( - q_starts, - kv_starts, - kv_indices, - kv_last_page_len, - self.flashinfer_extra_state.tp_q_head_num, - self.flashinfer_extra_state.tp_kv_head_num, - self.flashinfer_extra_state.head_dim, - 1, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=self.flashinfer_extra_state.q_data_type, - kv_data_type=self.flashinfer_extra_state.kv_data_type, - ) - return - - def copy_for_cuda_graph(self, new_infer_state): - super().copy_for_cuda_graph(new_infer_state) - if get_env_start_args().enable_flashinfer_decode and not self.is_prefill: - self.decode_wrapper.plan( - new_infer_state.kv_starts, - new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, - new_infer_state.flashinfer_extra_state.tp_q_head_num, - new_infer_state.flashinfer_extra_state.tp_kv_head_num, - new_infer_state.flashinfer_extra_state.head_dim, - 1, - q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, - kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, - non_blocking=True, - ) - return diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index 3bba439767..fe6ca392a2 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -14,7 +14,6 @@ def init_some_extra_state(self, model): super().init_some_extra_state(model) if self.is_prefill: self.max_seq_len = 
self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len position_ids = self.position_ids self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 7c7b0ea39b..8bc10d623c 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -13,8 +13,8 @@ class LlamaPostLayerInfer(PostLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.eps_ = network_config["rms_norm_eps"] return diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index ddb99e2627..f4f150b173 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -10,8 +10,8 @@ class LlamaPreLayerInfer(PreLayerInferTpl): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py old mode 100755 new mode 100644 index b08b2aa1fd..2a9a543196 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -1,51 +1,23 @@ import torch import triton -import torch.functional as F import torch.distributed as dist -import numpy as np -from typing import Tuple from functools import partial - from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, - context_attention_fwd_ppl_int8kv, -) -from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k -from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd -from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd - from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 from lightllm.common.basemodel import TransformerLayerInferTpl -from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args -from 
lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None logger = init_logger(__name__) -from lightllm.utils.sgl_utils import flash_attn_with_kvcache - class LlamaTransformerLayerInfer(TransformerLayerInferTpl): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.eps_ = network_config["rms_norm_eps"] self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ self.tp_k_head_num_ = max(network_config["num_key_value_heads"] // self.tp_world_size_, 1) @@ -58,7 +30,6 @@ def __init__(self, layer_num, network_config, mode=[]): def _bind_func(self): self._bind_norm() - self._bind_attention() return def _bind_norm(self): @@ -66,125 +37,34 @@ def _bind_norm(self): self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self) return - def _bind_attention(self): - if get_env_start_args().enable_fa3: - if "offline_calibration_fp8kv" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention_fp8, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - elif "export_fp8kv_calibration" in self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) - elif not self.mode: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashattention, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashattention, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") - return - elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) - else: - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) - if "ppl_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int8kv_flashdecoding_diverse" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - 
LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int8kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self - ) - elif "ppl_int4kv_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_int4kv_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int4kv, self) - elif "ppl_fp16" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "ppl_fp16_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_ppl_fp16_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_int8kv" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_int8kv, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_int8kv, self) - elif "offline_calibration_fp8kv" in self.mode: - assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_fp8kv, self) - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel_fp8, self - ) - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self - ) - elif "triton_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_attention" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_gqa_attention_normal, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_flashdecoding" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "triton_gqa_flashdecoding_vsm" in self.mode: - self._token_attention_kernel = partial( - LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self - ) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - elif "export_fp8kv_calibration" in self.mode: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self) - self._copy_kv_to_mem_cache = partial( - LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self - ) - elif not self.mode: - if get_env_start_args().enable_flashinfer_decode: - self._token_attention_kernel = partial( - 
LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self - ) - else: - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) - self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - else: - raise Exception(f"Unsupported mode: {self.mode}") + def _context_attention_kernel( + self, + q: torch.Tensor, + kv: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + alloc_func=self.alloc_tensor, + ) + o_tensor = o_tensor.view(q.shape) + return o_tensor - return + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: LlamaTransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + o_tensor = infer_state.decode_att_state.decode_att(q=_q, k=_k, v=_v, alloc_func=self.alloc_tensor) + return o_tensor.view(q.shape) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight @@ -241,163 +121,6 @@ def _tpsp_get_qkv( return q, cache_kv - def _context_attention_flashinfer_kernel_fp8( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - - def _context_attention_flashinfer_kernel( - self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) - infer_state.prefill_wrapper.run( - q.view(q.shape[0], -1, self.head_dim_), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(q.shape[0], -1, self.head_dim_), - ) - return o_tensor - - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return 
o_tensor - - def _context_attention_kernel_ppl_int8kv( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - batch_size = infer_state.b_seq_len.shape[0] - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv_scale = infer_state.mem_manager.scale_buffer[self.layer_num_] - max_seq_len = infer_state.max_seq_len - kv_dequant = self.alloc_tensor( - (batch_size, kv.shape[1], max_seq_len, kv.shape[2]), device=q.device, dtype=q.dtype - ) - destindex_copy_dequantize_kv( - kv, - kv_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_seq_len, - infer_state.b_req_idx, - max_seq_len, - kv_dequant, - ) - context_attention_fwd_ppl_int8kv( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv_dequant[:, 0 : self.tp_k_head_num_, :, :], - kv_dequant[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - infer_state.b_ready_cache_len, - ) - return o_tensor - - def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _context_attention_flashattention_fp8( - self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): - q, q_scale = q_per_head_fp8_quant( - q.view(q.shape[0], self.tp_k_head_num_, -1), - infer_state.b_seq_len, - infer_state.cu_seqlens_q, - infer_state.q_scale, - infer_state.token_batch_ids, - ) - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.q_max_seq_len, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale, - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def _get_o( self, 
input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: @@ -486,453 +209,6 @@ def _tpsp_ffn( # gate_out, up_out = None, None # return ffn2_out - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _copy_kv_to_mem_cache_with_calibration(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - mem_manager.update_calibration_data(buffer, self.layer_num_) - return - - def _copy_kv_to_mem_cache_int8kv(self, buffer, mem_index, mem_manager): - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _copy_kv_to_mem_cache_fp8kv(self, buffer, mem_index, mem_manager): - scales = mem_manager.scales - destindex_copy_kv_fp8( - buffer, - mem_index, - scales[self.layer_num_] if scales is not None else None, - mem_manager.kv_buffer[self.layer_num_].view(torch.float8_e4m3fn), - ) - return - - def _copy_kv_to_mem_cache_ppl_int8kv(self, buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_quantize_kv - - destindex_copy_quantize_kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): - from lightllm.models.llama.triton_kernel.ppl_int4kv_copy_kv import destindex_copy_int4kv - - destindex_copy_int4kv( - buffer, mem_index, mem_manager.kv_buffer[self.layer_num_], mem_manager.scale_buffer[self.layer_num_] - ) - return - - def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn) - v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn) - offline_scales = infer_state.mem_manager.scales_list - k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None - v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (k, v), - k_scale=k_descale, - v_scale=v_descale, - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - - def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) - infer_state.decode_wrapper.run( - q.view(calcu_shape1), - (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), - out=o_tensor.view(calcu_shape1), - ) - return o_tensor - - def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32) 
- - token_att_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o_tensor - - def _token_decode_gqa_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - # 对 gqa模型进行推理优化的代码 - from ..triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - gqa_decode_attention_fwd( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - ) - return o_tensor - - def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), q.dtype) - token_att_fwd_int8k( - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - prob = self.alloc_tensor(att_m_tensor.shape, att_m_tensor.dtype) - token_softmax_fwd( - att_m_tensor, infer_state.b_start_loc, infer_state.b_seq_len, prob, infer_state.max_len_in_batch - ) - att_m_tensor = None - - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - token_att_fwd2_int8v( - prob, - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - o_tensor.view(calcu_shape1), - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - prob = None - return o_tensor - - def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.llama.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = 
infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_gqa_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - # 对 gqa 模型进行推理优化的代码 - from ..triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return gqa_token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, - # at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - light_ops.group8_int8kv_decode_attention( - o_tensor.view(calcu_shape1), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - from lightllm_ppl_fp16_kernel import fp16_decode_attention - - # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, - # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) - fp16_decode_attention( - o_tensor.view(calcu_shape1), - 1.0 / (self.head_dim_ ** 0.5), - q.view(calcu_shape1), - infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ], - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - return o_tensor - - def _token_decode_attention_ppl_fp16_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_fp16_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = 
infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int8kv_flashdecoding_diverse( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( - token_decode_attention_flash_decoding, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_ppl_int4kv_flashdecoding( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.ppl_int4kv_flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_gqa_flashdecoding_vsm( - self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None - ): - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import ( - gqa_token_decode_attention_flash_decoding_vsm, - ) - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - 
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_) - return gqa_token_decode_attention_flash_decoding_vsm( - q.view(q_shape), - cache_k, - cache_v, - infer_state, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) - - def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ - ) - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) - k_descale, v_descale = None, None # disable quantization - Lq = q.shape[-1] - sm_scale = 1.0 / (Lq ** 0.5) - o = flash_attn_with_kvcache( - q=q, - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_att_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=infer_state.max_q_seq_len, - softmax_scale=sm_scale, - causal=True, - window_size=(-1, -1), - softcap=0.0, - k_descale=k_descale, - v_descale=v_descale, - return_softmax_lse=False, - ) - return o - - def _token_decode_attention_flashattention_fp8( - self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None - ): - cache_k = ( - (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :]) - .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - cache_v = ( - ( - infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - ) - .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) - .view(torch.float8_e4m3fn) - ) - q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True) - o = flash_attn_with_kvcache( - q=q.view(-1, self.tp_q_head_num_, self.head_dim_), - k_cache=cache_k, - v_cache=cache_v, - page_table=infer_state.page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=infer_state.cu_seqlens_q, - cu_seqlens_k_new=infer_state.cu_seqlens_k, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_), - k_descale=infer_state.k_descale[self.layer_num_], - v_descale=infer_state.v_descale[self.layer_num_], - return_softmax_lse=False, - ) - return o - def overlap_tpsp_token_forward( self, input_embdings: torch.Tensor, diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index ea59d24dfc..7e9ff41673 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py 
b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 6b92272ee7..197116d99c 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -11,10 +11,9 @@ def __init__( layer_num, data_type, network_config, - mode=[], quant_cfg=None, ): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 95465a9e6c..c104ebccc9 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -9,39 +9,15 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo -from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo from lightllm.common.basemodel import TpPartBaseModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args -from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id logger = init_logger(__name__) -class LlamaFlashInferStateExtraInfo: - def __init__(self, model): - tp_world_size = get_dp_world_size() - self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size - self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) - head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] - self.head_dim = model.config.get("head_dim", head_dim) - self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - self.q_data_type = model.data_type - self.kv_data_type = torch.float8_e4m3fn if "offline_calibration_fp8kv" in model.mode else model.data_type - - @ModelRegistry("llama") class LlamaTpPartModel(TpPartBaseModel): # weight class @@ -57,9 +33,6 @@ class LlamaTpPartModel(TpPartBaseModel): infer_state_class = LlamaInferStateInfo def __init__(self, kvargs): - self.enable_flashinfer = ( - get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode - ) super().__init__(kvargs) return @@ -94,13 +67,6 @@ def _init_mem_manager(self): ) return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo - elif self.enable_flashinfer: - self.infer_state_class = LlamaFlashInferStateInfo - self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self) - def _init_custom(self): """ 模型特殊的一些初始化 diff --git a/lightllm/models/llama/triton_kernel/flash_decoding.py b/lightllm/models/llama/triton_kernel/flash_decoding.py deleted file mode 100644 index e47e308864..0000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding.py +++ /dev/null @@ 
-1,37 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None, alloc_tensor_func=torch.empty -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from .flash_decoding_stage1 import flash_decode_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" - ) - - flash_decode_stage1( - q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ, - ) - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py deleted file mode 100644 index 86a3af103d..0000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage1.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import triton -import triton.language as tl - -@triton.jit -def _fwd_kernel_flash_decode_stage1( - Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - gqa_group_size, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - cur_kv_head = cur_head // gqa_group_size - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - - q = tl.load(Q + off_q) - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0) - k_loc = k_loc.to(tl.int64) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) - v = tl.load(V + off_k, 
mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - - cur_max_logic = tl.max(att_value, axis=0) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale - acc += tl.sum(exp_logic[:, None] * v, axis=0) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block - tl.store(Mid_O + off_mid_o, acc / sum_exp) - tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) - return - - -@torch.no_grad() -def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq): - BLOCK_SEQ = block_seq - BLOCK_N = 16 - assert BLOCK_SEQ % BLOCK_N == 0 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk ** 0.5) - batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) - gqa_group_size = q.shape[1] // k.shape[1] - - _fwd_kernel_flash_decode_stage1[grid]( - q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2), - gqa_group_size, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK_N, - num_warps=1, - num_stages=2, - ) - return \ No newline at end of file diff --git a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py b/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py deleted file mode 100644 index 81227f967b..0000000000 --- a/lightllm/models/llama/triton_kernel/flash_decoding_stage2.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_flash_decode_stage2( - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, #[batch, head, seq_block_num] - O, #[batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, - stride_obs, stride_oh, stride_od, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ - - sum_exp = 0.0 - max_logic = -float("inf") - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, block_n_size, 1): - tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) - tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) - new_max_logic = tl.maximum(tlogic, max_logic) - - old_scale = tl.exp(max_logic - new_max_logic) - acc *= old_scale - 
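The flash-decoding kernels removed above split each sequence into fixed-size KV blocks, emit a normalized partial output plus a log-sum-exp per block (stage 1), and then merge the partials with a running-max rescaling (stage 2). Below is a plain-PyTorch sketch of that two-stage reduction, checked against ordinary softmax attention; the shapes and block size are illustrative, not the kernels' launch configuration:

```python
import math
import torch

torch.manual_seed(0)
seq_len, head_dim, block = 1000, 128, 256
q = torch.randn(head_dim)
k = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)
sm_scale = head_dim ** -0.5

# Stage 1: per KV block, a normalized partial output and the block's log-sum-exp.
mid_o, mid_lse = [], []
for start in range(0, seq_len, block):
    kb, vb = k[start:start + block], v[start:start + block]
    logits = (kb @ q) * sm_scale
    m = logits.max()
    p = torch.exp(logits - m)
    mid_o.append((p[:, None] * vb).sum(0) / p.sum())
    mid_lse.append((m + torch.log(p.sum())).item())

# Stage 2: merge the partials with a running max, as in the deleted stage-2 kernel.
acc = torch.zeros(head_dim)
sum_exp, max_logic = 0.0, float("-inf")
for o_b, lse_b in zip(mid_o, mid_lse):
    new_max = max(max_logic, lse_b)
    old_scale = math.exp(max_logic - new_max)   # rescale what was accumulated so far
    new_scale = math.exp(lse_b - new_max)       # weight of this block's partial output
    acc = acc * old_scale + new_scale * o_b
    sum_exp = sum_exp * old_scale + new_scale
    max_logic = new_max
out_two_stage = acc / sum_exp

# Reference: ordinary softmax attention over the full sequence.
ref = torch.softmax((k @ q) * sm_scale, dim=0) @ v
assert torch.allclose(out_two_stage, ref, atol=1e-4)
```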
exp_logic = tl.exp(tlogic - new_max_logic) - acc += exp_logic * tv - sum_exp = sum_exp * old_scale + exp_logic - max_logic = new_max_logic - - tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) - return - - -@torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): - Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} - batch, head_num = mid_out.shape[0], mid_out.shape[1] - grid = (batch, head_num) - - _fwd_kernel_flash_decode_stage2[grid]( - B_Seqlen, mid_out, mid_out_logexpsum, O, - mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3), - mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2), - O.stride(0), O.stride(1), O.stride(2), - BLOCK_SEQ=block_seq, - BLOCK_DMODEL=Lk, - num_warps=4, - num_stages=2, - ) - return \ No newline at end of file diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py deleted file mode 100644 index 7ba0f3b31b..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_copy_kv.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_int4_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM // 2) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data_0 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - src_data_1 = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :] * 2 + 1, - mask=offs_g[:, None] < group_size, - other=0.0, - ) - - abs_data_0 = tl.abs(src_data_0) - abs_data_1 = tl.abs(src_data_1) - - data_scale = (tl.maximum(tl.max(abs_data_0, axis=1), tl.max(abs_data_1, axis=1)) / 7.0).to(Out_scale.dtype.element_ty) - q_src_data_0 = (src_data_0 / data_scale[:, None]).to(tl.int8) - q_src_data_0 = tl.where(q_src_data_0 > 7, 7, q_src_data_0) - q_src_data_0 = tl.where(q_src_data_0 < -7, -7, q_src_data_0) - - q_src_data_1 = (src_data_1 / data_scale[:, None]).to(tl.int8) - q_src_data_1 = tl.where(q_src_data_1 > 7, 7, q_src_data_1) - q_src_data_1 = tl.where(q_src_data_1 < -7, -7, q_src_data_1) - - low_4 = ((q_src_data_0 & 0x80) >> 4) | (q_src_data_0 & 0xF) - high_4 = (((q_src_data_1 & 0x80) >> 4) | (q_src_data_1 & 0xF)) << 4 - - # tl.device_print(low_4) - # tl.device_print(high_4) - - out_data = low_4 | high_4 - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, out_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_int4kv(K, DestLoc, Out, Out_scale): - # seq_len = DestLoc.shape[0] - # head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to 
copy quant kv" - # grid = (seq_len, head_num) - # num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view( - Out.shape[0], Out.shape[1], group_size, group_dim // 2 - ) # OUt 是 int8 类型, 两个int4组一个int8,所以 group_dim // 2 - - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_copy_int4_kv - - group8_copy_int4_kv(Out, Out_scale, K, DestLoc, 4) - - # _fwd_kernel_destindex_copy_quantize_int4_kv[grid]( - # K, - # DestLoc, - # Out, - # Out_scale, - # K.stride(0), - # K.stride(1), - # K.stride(2), - # K.stride(3), - # Out.stride(0), - # Out.stride(1), - # Out.stride(2), - # Out.stride(3), - # Out_scale.stride(0), - # Out_scale.stride(1), - # Out_scale.stride(2), - # group_size, - # BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - # BLOCK_GROUP_DIM=group_dim, - # num_warps=num_warps, - # num_stages=1, - # ) - return - - -def test2(): - import time - - src = torch.randn((1, 1, 8), dtype=torch.float16).cuda() - src[0, 0, :] = torch.tensor([1, -2, 2, 0, 4, 5, 6, 7]).cuda() - dest_loc = torch.arange(0, 1, dtype=torch.int32).cuda() - value_dest = torch.randn((1, 1, 4), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((1, 1, 1), dtype=torch.float16).cuda() - - destindex_copy_int4kv(src, dest_loc, value_dest, scale_dest) - - print(value_dest) - print(scale_dest) - - -if __name__ == "__main__": - test2() diff --git a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py b/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py deleted file mode 100644 index 1e324bcc0b..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_int4kv_flash_decoding.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch - - -def token_decode_attention_flash_decoding( - q, - infer_state, - q_head_num, - head_dim, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=None, - alloc_tensor_func=torch.empty, -): - BLOCK_SEQ = 256 - batch_size = infer_state.batch_size - max_len_in_batch = infer_state.max_len_in_batch - calcu_shape1 = (batch_size, q_head_num, head_dim) - - from lightllm_ppl_int4kv_flashdecoding_kernel import group8_int4kv_flashdecoding_stage1 - from .flash_decoding_stage2 import flash_decode_stage2 - - o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float16, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float16, device="cuda" - ) - - group8_int4kv_flashdecoding_stage1( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) - return o_tensor diff --git a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py b/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py deleted file mode 100644 index 3d9a490f47..0000000000 --- a/lightllm/models/llama/triton_kernel/ppl_quant_copy_kv.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - 
stride_k_h, - stride_k_g, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_g, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_g, - group_size, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_head = tl.program_id(1) - - offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - dest_index = tl.load(Dest_loc + cur_index).to(tl.int64) - - src_data = tl.load( - K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :], - mask=offs_g[:, None] < group_size, - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty) - q_src_data = (src_data / data_scale[:, None]).to(tl.int8) - - o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g - tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size) - tl.store(os_ptrs, data_scale, mask=offs_g < group_size) - return - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - quant_group_dim = 8 - - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (seq_len, head_num) - num_warps = 1 - - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - - K = K.view((K.shape[0], K.shape[1], group_size, group_dim)) - Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim) - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - K.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - group_size, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_dequantize_kv( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - stride_kv_b, - stride_kv_h, - stride_kv_g, - stride_kv_d, - stride_o_bh, - stride_o_l, - stride_o_g, - stride_o_d, - stride_s_b, - stride_s_h, - stride_s_g, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - group_size, - head_num: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_GROUP_NUM: tl.constexpr, - BLOCK_GROUP_DIM: tl.constexpr, -): - cur_group = tl.program_id(0) - start_m = tl.program_id(1) - cur_bh = tl.program_id(2) - cur_batch = cur_bh // head_num - cur_head = cur_bh % head_num - - block_start_loc = BLOCK_SIZE * start_m - cur_batch_req_idx = tl.load(b_req_idx + cur_batch) - cur_seq_len = tl.load(b_seq_len + cur_batch) - - # initialize offsets - offs_kv_loc = block_start_loc + tl.arange(0, BLOCK_SIZE) - - # offs_g = tl.arange(0, BLOCK_GROUP_NUM) - offs_d = tl.arange(0, BLOCK_GROUP_DIM) - - kv_loc = tl.load( - req_to_token_indexs + cur_batch_req_idx * stride_req_to_tokens_b + offs_kv_loc, mask=offs_kv_loc < cur_seq_len - ).to(tl.int64) - offs_kv = kv_loc[:, None] * stride_kv_b + cur_head * stride_kv_h + cur_group * stride_kv_g + offs_d[None, :] - - src_data = tl.load( - mem_kv_buffer + offs_kv, - mask=offs_kv_loc[:, None] < cur_seq_len, - other=0.0, - ).to(Out.dtype.element_ty) - - s_ptrs = mem_kv_scale + kv_loc * stride_s_b + cur_head * stride_s_h + cur_group * stride_s_g - 
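The int8 variant removed here combines per-group quantization (scale = max(|x|)/127 over groups of 8 channels) with a scatter into the pre-allocated cache at the slot indices handed out by the memory manager. A slow reference with made-up shapes (it rounds where the Triton kernel truncates):

```python
import torch

def destindex_copy_quantize_kv_ref(k, dest_loc, out, out_scale, group=8):
    """Quantize new KV rows per group and scatter them into the cache at dest_loc."""
    seq, heads, head_dim = k.shape
    g = k.reshape(seq, heads, head_dim // group, group)
    scale = (g.abs().amax(dim=-1) / 127.0).clamp_min(1e-8)            # [seq, heads, groups]
    q = torch.clamp(torch.round(g / scale.unsqueeze(-1)), -127, 127).to(torch.int8)
    out[dest_loc] = q.reshape(seq, heads, head_dim)
    out_scale[dest_loc] = scale.to(out_scale.dtype)

cache_tokens, heads, head_dim = 64, 4, 32
kv_cache = torch.zeros(cache_tokens, heads, head_dim, dtype=torch.int8)
kv_scale = torch.zeros(cache_tokens, heads, head_dim // 8, dtype=torch.float16)

new_k = torch.randn(3, heads, head_dim)
dest_loc = torch.tensor([5, 17, 42])           # slots assigned by the memory manager
destindex_copy_quantize_kv_ref(new_k, dest_loc, kv_cache, kv_scale)

# Dequantize one written slot and compare against the original row.
deq = kv_cache[5].reshape(heads, -1, 8).float() * kv_scale[5].float().unsqueeze(-1)
print((deq.reshape(heads, head_dim) - new_k[0]).abs().max())
```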
data_scale = tl.load( - s_ptrs, - mask=offs_kv_loc < cur_seq_len, - ) - - out_data = src_data * data_scale[:, None] - o_ptrs = Out + cur_bh * stride_o_bh + offs_kv_loc[:, None] * stride_o_l + cur_group * stride_o_g + offs_d[None, :] - tl.store(o_ptrs, out_data, mask=offs_kv_loc[:, None] < cur_seq_len) - return - - -@torch.no_grad() -def destindex_copy_dequantize_kv( - mem_kv_buffer, mem_kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_len_in_batch, Out -): - batch_size = b_seq_len.shape[0] - head_num = mem_kv_buffer.shape[1] - head_dim = mem_kv_buffer.shape[2] - quant_group_dim = 8 - BLOCK_SIZE = 128 - group_size = head_dim // quant_group_dim - group_dim = quant_group_dim - assert head_dim % quant_group_dim == 0, "error head dim, can not been supported to copy quant kv" - grid = (group_size, triton.cdiv(max_len_in_batch, BLOCK_SIZE), batch_size * head_num) - num_warps = 1 - mem_kv_buffer = mem_kv_buffer.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], group_size, group_dim)) - mem_kv_scale = mem_kv_scale.view((mem_kv_buffer.shape[0], mem_kv_buffer.shape[1], -1)) - Out = Out.view(Out.shape[0] * Out.shape[1], -1, group_size, group_dim) - - _fwd_kernel_destindex_copy_dequantize_kv[grid]( - mem_kv_buffer, - mem_kv_scale, - req_to_token_indexs, - b_seq_len, - b_req_idx, - Out, - mem_kv_buffer.stride(0), - mem_kv_buffer.stride(1), - mem_kv_buffer.stride(2), - mem_kv_buffer.stride(3), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out.stride(3), - mem_kv_scale.stride(0), - mem_kv_scale.stride(1), - mem_kv_scale.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - group_size, - head_num=head_num, - BLOCK_SIZE=BLOCK_SIZE, - BLOCK_GROUP_NUM=triton.next_power_of_2(group_size), - BLOCK_GROUP_DIM=group_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test2(): - import time - - B, N_CTX, H, D = 1, 3, 12, 128 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - value_dest = value_dest.view((B * N_CTX, H, D // 8, 8)) - scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1)) - print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -def torch_dequant(kv, kv_scale, o, b_req_idx, b_seq_len, req_to_token_indexs): - - batch = b_req_idx.shape[0] - for i in range(batch): - req_idx = b_req_idx[i] - seq_len = b_seq_len[i] - print(seq_len, b_seq_len) - kv_loc = req_to_token_indexs[req_idx, :seq_len] - head_num = kv.shape[1] - cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, 8).to(o.dtype) - cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) - out = cur_kv * cur_scale - o[i, :seq_len, :, :] = out.reshape(out.shape[0], out.shape[1], -1) - - -def test3(): - import time - import numpy as np - - Z, H, N_CTX, 
D_HEAD = 1, 16, 3, 128 - dtype = torch.bfloat16 - kv = torch.empty((Z * N_CTX + 100, 2 * H, D_HEAD), dtype=torch.int8, device="cuda") - kv_scale = torch.randn((Z * N_CTX + 100, 2 * H, D_HEAD // 8), dtype=dtype, device="cuda") - out = torch.empty((Z, 2 * H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - torch_out = torch.empty((Z, N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.empty((1000, N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - for i in range(Z): - seq_len = N_CTX - i * 100 - b_seq_len[i] = seq_len - b_req_idx[i] = i - req_to_token_indexs[i][:seq_len] = ( - torch.tensor(np.arange(seq_len), dtype=torch.int32).cuda() + b_seq_len[0:i].sum() - ) - print(b_seq_len) - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch_dequant(kv, kv_scale, torch_out, b_req_idx, b_seq_len, req_to_token_indexs) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_dequantize_kv(kv, kv_scale, req_to_token_indexs, b_seq_len, b_req_idx, max_input_len, out) - torch.cuda.synchronize() - t2 = time.time() - print((t2 - t1)) - torch_out = torch_out.transpose(1, 2) - for i in range(Z): - print("max ", torch.max(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - print("mean ", torch.mean(torch.abs(torch_out - out)[i][:, : b_seq_len[i]])) - assert torch.allclose(torch_out[i][:, : b_seq_len[i]], out[i][:, : b_seq_len[i]], atol=1e-2, rtol=0) - # print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src))) - # cos = torch.nn.CosineSimilarity(0) - # print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test3() diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py deleted file mode 100644 index 243a8d1f66..0000000000 --- a/lightllm/models/llama/triton_kernel/token_attention_nopad_reduceV.py +++ /dev/null @@ -1,223 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_token_att2( - Prob, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - 
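The token_attention_nopad_reduceV kernels being deleted compute the probability-weighted sum of cached V vectors, locating each request's tokens through the req_to_token_indexs page table; the int8 variant additionally multiplies by the stored V scales. A readable (but slow) PyTorch reference of the gather-and-reduce, with invented shapes for the demo:

```python
import torch

def token_att_reduce_v_ref(prob, v, req_to_tokens, b_req_idx, b_start_loc, b_seq_len):
    """out[b, h] = sum_n prob[h, start_b + n] * V[req_to_tokens[req_b, n], h // kv_group]."""
    head = prob.shape[0]
    kv_group = head // v.shape[1]
    out = torch.zeros(len(b_req_idx), head, v.shape[-1], dtype=v.dtype)
    for b, (req, start, seqlen) in enumerate(zip(b_req_idx, b_start_loc, b_seq_len)):
        token_locs = req_to_tokens[req, :seqlen]          # rows of the paged KV cache
        for h in range(head):
            p = prob[h, start:start + seqlen]             # this head's attention weights
            out[b, h] = (p[:, None] * v[token_locs, h // kv_group]).sum(0)
    return out

heads, kv_heads, dim = 8, 2, 64
v = torch.randn(100, kv_heads, dim)
req_to_tokens = torch.arange(100).reshape(4, 25)
b_req_idx = torch.tensor([0, 2])
b_seq_len = torch.tensor([10, 25])
b_start_loc = torch.tensor([0, 10])
prob = torch.rand(heads, 35)
out = token_att_reduce_v_ref(prob, v, req_to_tokens, b_req_idx, b_start_loc, b_seq_len)
print(out.shape)                                          # torch.Size([2, 8, 64])
```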
start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen): - BLOCK = 128 - # BLOCK = 64 # for triton 2.0.0dev - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2[grid]( - prob, - v, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_token_att2_int8v( - Prob, - V, - V_scale, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, # B_Start_Loc 保存的是如果连续存储时候的累加输入和 - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_vsbs, - stride_vsh, - stride_vsd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = 0 - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - vs_offs = cur_kv_head * stride_vsh - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0) - v_loc = tl.load( - Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s, - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0.0, - ).to(tl.int64) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0 - ) - vs_value = tl.load( - V_scale + vs_offs + v_loc[:, None] * stride_vsbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value * vs_value, 0) - - acc = acc.to(Out.dtype.element_ty) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_att_fwd2_int8v(prob, v, v_scale, out, Req_to_tokens, B_req_idx, 
B_Start_Loc, B_Seqlen, max_len_in_batch): - if max_len_in_batch < 512: - BLOCK = triton.next_power_of_2(max_len_in_batch) - else: - BLOCK = 512 - batch, head = B_req_idx.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - kv_group_num = prob.shape[0] // v.shape[1] - - _fwd_kernel_token_att2_int8v[grid]( - prob, - v, - v_scale, - out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - v_scale.stride(0), - v_scale.stride(1), - v_scale.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - out = torch.matmul(P, V) - - return out diff --git a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py index b4c070a1e6..3afcfb0a71 100644 --- a/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llava/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight @@ -14,8 +12,8 @@ def rename_weight_keys(weights): class LlavaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py index 0952468d0f..45023bdf8f 100644 --- a/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/pre_and_post_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) hidden_size = self.network_config_["hidden_size"] dim_model_base = self.network_config_.get("dim_model_base", hidden_size) self.lm_head_scale = hidden_size / dim_model_base diff --git a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py index 2bc5078382..c37b524fde 100755 --- a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MiniCPMTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py index 59eef6daa7..d115c30ec1 100755 --- 
a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -4,7 +4,7 @@ class MistralTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config.get("head_dim", self.head_dim_) return diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index d32f51ae78..f09525c59f 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -1,5 +1,3 @@ -import os -import json import torch from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel import TpPartBaseModel @@ -8,7 +6,6 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num @@ -43,10 +40,6 @@ def _init_custom(self): self._init_to_get_rotary() return - def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: - self.infer_state_class = FlashAttentionStateInfo - def _init_mem_manager(self): # Dealing with head_dim_!=n_embed // num_attention_heads scenarios, such as mistral 13B head_dim = self.config["hidden_size"] // self.config["num_attention_heads"] diff --git a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py b/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index abcaf02b51..0000000000 --- a/lightllm/models/mistral/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - sliding_window, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - 
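The Mistral context-attention kernel deleted here applies two masks per tile: the usual causal constraint (key position <= query position) and a sliding-window cut-off (key position > query position - sliding_window). A compact PyTorch sketch of that combined mask applied to a toy attention call (sizes are illustrative):

```python
import torch

def sliding_window_causal_mask(seq_len, window):
    """Positions j attendable from i: i - window < j <= i, matching the deleted kernel's masking."""
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

q = torch.randn(6, 1, 16)    # [seq, heads, head_dim], toy sizes
k = torch.randn(6, 1, 16)
v = torch.randn(6, 1, 16)
mask = sliding_window_causal_mask(6, window=3)

scores = torch.einsum("qhd,khd->hqk", q, k) / 16 ** 0.5
scores = scores.masked_fill(~mask, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v.transpose(0, 1)   # [heads, seq, head_dim]
print(mask.int())                                         # banded lower-triangular pattern
```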
v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # [SYM] mask outside of windows - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - qk = tl.where((start_n + offs_n[None, :]) > (offs_m[:, None] - sliding_window), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - -@torch.no_grad() -def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, sliding_window): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - sliding_window=sliding_window, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, 
dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - - Z, H, N_CTX, D_HEAD = 4, 6, 1024, 128 - dtype = torch.float16 - Z = 3 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - - max_input_len = N_CTX - Z = 4 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - - b_seq_len[0] = 512 - b_seq_len[1] = 1024 - b_seq_len[2] = 512 - b_seq_len[3] = 1024 - - for i in range(1, Z): - b_start_loc[i] = b_start_loc[i - 1] + b_seq_len[i - 1] - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att(q[start:end], k[start:end], v[start:end], 1, b_seq_len[i], H, D_HEAD) - start = end - torch_out.append(torch_o) - torch_out = torch.cat(torch_out, dim=0) - context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, 10) - print(o.shape, torch_out.shape) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py deleted file mode 100644 index a60fe970b3..0000000000 --- a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel_init_att_window_info( - b_seq_len, - b_att_seq_len, - batch_size, - sliding_window, - BLOCK_SIZE: tl.constexpr, -): - cur_index = tl.program_id(0) - cur_start = cur_index * BLOCK_SIZE - offsets = cur_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < batch_size - - cur_seq_len = tl.load(b_seq_len + offsets, mask=mask) - b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window) - - tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask) - return - - -@torch.no_grad() -def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window): - # shape constraints - assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0] - - BLOCK_SIZE = 32 - num_warps = 1 - grid = (triton.cdiv(batch_size, BLOCK_SIZE),) - - _fwd_kernel_init_att_window_info[grid]( - b_seq_len, - b_att_seq_len, - batch_size=batch_size, - sliding_window=sliding_window, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py deleted file mode 100644 index bf9928f987..0000000000 --- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _fwd_kernel( - Logics, - V, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - B_Att_Start_Loc, - B_Att_Seqlen, - stride_logic_h, - stride_logic_bs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - 
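The deleted init_att_sliding_window_info kernel only clamps each request's attended length to the window, and the matching decode kernel then starts reading the cache at seq_len - att_seq_len. The same bookkeeping in a few lines of PyTorch, with made-up lengths:

```python
import torch

def init_att_window_info(b_seq_len, sliding_window):
    """Clamp each request's attended length to the sliding window."""
    return torch.minimum(b_seq_len, torch.tensor(sliding_window, dtype=b_seq_len.dtype))

b_seq_len = torch.tensor([3, 500, 4096, 8192], dtype=torch.int32)
b_att_seq_len = init_att_window_info(b_seq_len, sliding_window=4096)
print(b_att_seq_len)            # tensor([   3,  500, 4096, 4096], dtype=torch.int32)

# The decode path then reads the cache starting at seq_len - att_seq_len,
# i.e. max(seq_len - window, 0), so only the last `window` tokens are attended.
cache_start = b_seq_len - b_att_seq_len
print(cache_start)              # tensor([   0,    0,    0, 4096], dtype=torch.int32)
```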
stride_req_to_token_b, - stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 - kv_group_num, - sliding_window, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index - cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] - v_ptrs = V + off_v - - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] - - for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s, - mask=(start_n + offs_n) < cur_att_seq_len, - other=other_kv_index, - ) # [64] - - qk = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=(start_n + offs_n) < cur_att_seq_len, - other=float("-inf"), - ) # [64] - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D] - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D] - e_max = n_e_max - - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - -@torch.no_grad() -def token_softmax_reducev_fwd( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - sliding_window, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logics.shape[0] - grid = (batch, head) - kv_group_num = logics.shape[0] // v.shape[1] - - num_warps = 1 - _fwd_kernel[grid]( - logics, - v, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_att_start_loc, - b_att_seq_len, - logics.stride(0), - logics.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_tokens.stride(0), - req_to_tokens.stride(1), - 0, - kv_group_num, - sliding_window, - BLOCK_DMODEL=v.shape[-1], - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, - ) - return diff --git a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py index 5eac249bad..f890fbf663 100644 --- a/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/post_layer_infer.py @@ -4,6 +4,6 @@ class MistralMTPPostLayerInfer(LlamaPostLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index 25bea1aa60..dbe9b61c85 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ 
b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -7,8 +7,8 @@ class MistralMTPPreLayerInfer(LlamaPreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def _mtp_context_forward( diff --git a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py index 5724f32af9..6d72ae2c38 100644 --- a/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/transformer_layer_infer.py @@ -10,8 +10,8 @@ class MistralMTPTransformerLayerInfer(MistralTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index 2fbc89cfd0..c9032f6fee 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="mtp.eh_proj.weight", diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 6607dbb704..08f280b06c 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class MistralMTPTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index a60375688b..44e66cff2d 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -3,14 +3,13 @@ import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mistral.layer_infer.transformer_layer_infer import MistralTransformerLayerInfer from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight class MixtralTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + 
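Across these files the PR drops the `mode` argument from the layer-weight and layer-infer constructors and from the corresponding `super().__init__` calls. For downstream code that subclasses these classes the migration is mechanical; the sketch below uses hypothetical stand-in class names, not real LightLLM classes:

```python
# Hypothetical stand-in for a layer-weight base class (illustrative only).
class BaseLayerWeight:
    def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
        self.layer_num_ = layer_num
        self.data_type_ = data_type
        self.network_config_ = network_config
        self.quant_cfg = quant_cfg

class MyCustomLayerWeight(BaseLayerWeight):
    # Before this PR the signature carried an extra positional `mode=[]` argument
    # that was forwarded to super(); after the PR it is simply removed.
    def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
        super().__init__(layer_num, data_type, network_config, quant_cfg)

w = MyCustomLayerWeight(0, "bfloat16", {"hidden_size": 1024})
print(w.network_config_["hidden_size"])
```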
super().__init__(layer_num, network_config) self.num_local_experts = network_config["num_local_experts"] self.num_experts_per_tok = network_config["num_experts_per_tok"] self.renormalize = True diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index f425ad08ba..39e28d4655 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -8,12 +8,11 @@ class MixtralTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__( layer_num, data_type, network_config, - mode, quant_cfg=quant_cfg, ) return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index 806c59365b..fd3d05e426 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,11 +1,5 @@ -import torch -from functools import partial from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -13,14 +7,8 @@ class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) - return - - def _bind_attention(self): - self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self) - self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): @@ -35,44 +23,3 @@ def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Ph infer_state.position_sin, ) return q, cache_kv - - def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager): - destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) - return - - def _context_attention_kernel( - self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None - ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, - ) - return o_tensor - - def 
_token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): - from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding - - cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] - cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ - :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ] - return token_decode_attention_flash_decoding( - q, - infer_state, - self.tp_q_head_num_, - self.head_dim_, - cache_k, - cache_v, - out=out, - alloc_tensor_func=self.alloc_tensor, - ) diff --git a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py index 91b2730917..db4906c19a 100755 --- a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py @@ -6,8 +6,8 @@ class Phi3TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def load_hf_weights(self, weights): diff --git a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py deleted file mode 100644 index ee04c3367b..0000000000 --- a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py +++ /dev/null @@ -1,433 +0,0 @@ -import torch - -import triton -import triton.language as tl -import math -import torch.nn.functional as F - -from lightllm.utils.device_utils import is_tesla - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - Req_to_tokens, - B_req_idx, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - kv_group_num, - b_prompt_cache_len, - head_dim: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, 
cur_batch_seq_len + prompt_cache_len) - - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, - other=0, - ).to(tl.int64) - off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load( - K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0 - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) - acc = acc * acc_scale[:, None] - # update acc - off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load( - V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0 - ) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs -): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - req_to_token_indexs, - b_req_idx, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), - kv_group_num=kv_group_num, - b_prompt_cache_len=b_prompt_cache_len, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_no_prompt_cache( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - head_dim, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, 
- BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (offs_d[:, None] < head_dim), - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (offs_d[None, :] < head_dim), - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 if not is_tesla() else 64 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - head_dim = Lq - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_no_prompt_cache[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - 
b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - head_dim=head_dim, - BLOCK_M=BLOCK, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim, prompt_cache_len): - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - xv = xv.view(bs, seqlen + prompt_cache_len, num_head, head_dim) - mask_cache = torch.ones((seqlen, prompt_cache_len)).cuda().unsqueeze(0).unsqueeze(0).cuda() - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.0] = -100000000.0 - mask = torch.cat([mask_cache, mask], dim=-1) - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) - scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - - -def test(): - import torch - import numpy as np - - Z, H, N_CTX, D_HEAD = 10, 6, 500, 96 - dtype = torch.float16 - Z = 1 - q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - req_to_token_indexs = torch.zeros((10, Z * N_CTX + 7000), dtype=torch.int32, device="cuda") - max_input_len = N_CTX - Z = 1 - b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_prompt_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prompt_cache_len[0] = 0 - prompt_cache_len = 0 - - b_seq_len[0] = 500 - b_req_idx[0] = 0 - req_to_token_indexs[0][: prompt_cache_len + N_CTX] = torch.tensor( - np.arange(prompt_cache_len + N_CTX), dtype=torch.int32 - ).cuda() - - torch_out = [] - start = 0 - for i in range(Z): - end = start + b_seq_len[i] - torch_o = torch_att( - q[start:end], - k[start : end + prompt_cache_len], - v[start : end + prompt_cache_len], - 1, - b_seq_len[i], - H, - D_HEAD, - prompt_cache_len, - ) - start = end - torch_out.append(torch_o) - - torch_out = torch.cat(torch_out, dim=0) - - context_attention_fwd( - q, - k, - v, - o, - b_req_idx, - b_start_loc, - b_seq_len + prompt_cache_len, - b_prompt_cache_len, - max_input_len, - req_to_token_indexs, - ) - - # context_attention_fwd_no_prompt_cache( - # q, k, v, o, b_start_loc, b_seq_len, max_input_len - # ) - - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py b/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py deleted file mode 100644 index 4f31895ae0..0000000000 --- a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch - -import triton -import 
triton.language as tl - - -@triton.jit -def _fwd_kernel_destindex_copy_kv( - K, - Dest_loc, - Out, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - - k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - - k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0) - tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - return - - -@torch.no_grad() -def destindex_copy_kv(K, DestLoc, Out): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_kv[grid]( - K, - DestLoc, - Out, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -@triton.jit -def _fwd_kernel_destindex_copy_quantize_kv( - K, - Dest_loc, - Out, - Out_scale, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_o_bs, - stride_o_h, - stride_o_d, - stride_os_bs, - stride_os_h, - stride_os_d, - head_num, - head_dim, - BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr, -): - cur_index = tl.program_id(0) - offs_h = tl.arange(0, BLOCK_HEAD) - offs_d = tl.arange(0, BLOCK_DMODEL) - - dest_index = tl.load(Dest_loc + cur_index) - src_data = tl.load( - K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], - mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), - other=0.0, - ) - abs_data = tl.abs(src_data) - data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] - q_src_data = (src_data / data_scale).to(tl.int8) - o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] - os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] - tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) - tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num)) - - -@torch.no_grad() -def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): - seq_len = DestLoc.shape[0] - head_num = K.shape[1] - head_dim = K.shape[2] - assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] - BLOCK_HEAD = triton.next_power_of_2(head_num) - BLOCK_DMODEL = triton.next_power_of_2(head_dim) - grid = (seq_len,) - num_warps = 1 - - _fwd_kernel_destindex_copy_quantize_kv[grid]( - K, - DestLoc, - Out, - Out_scale, - K.stride(0), - K.stride(1), - K.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - Out_scale.stride(0), - Out_scale.stride(1), - Out_scale.stride(2), - head_num, - head_dim, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_HEAD=BLOCK_HEAD, - num_warps=num_warps, - num_stages=1, - ) - return - - -def test1(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - src = torch.randn((B * 
N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") - - for _ in range(10): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_kv(src, dest_loc, dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(dest - src))) - print("mean ", torch.mean(torch.abs(dest - src))) - assert torch.allclose(src, dest, atol=1e-2, rtol=0) - - -def test2(): - import time - - B, N_CTX, H, D = 32, 1024, 12, 96 - src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() - dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() - value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) - scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() - - for _ in range(10): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t1 = time.time() - for _ in range(1000): - destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) - torch.cuda.synchronize() - t2 = time.time() - - print("Time cost ", t2 - t1) - print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) - print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) - cos = torch.nn.CosineSimilarity(0) - print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) - - -if __name__ == "__main__": - test1() - test2() diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 7a4b2ca816..333870eb9d 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -1,7 +1,4 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np from typing import Tuple from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -12,8 +9,8 @@ class QwenTransformerLayerInfer(LlamaTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index 00f68eee69..bf9282a979 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 9afb964ad2..ac1bf91f4b 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class QwenTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): qkv_weight_name = f"transformer.h.{self.layer_num_}.attn.c_attn.weight" diff --git a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py index a8a57c02ed..6449430d9e 100644 --- a/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/pre_and_post_layer_weight.py @@ -2,6 +2,6 @@ class Qwen2PreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 6962818c49..9c3e2cb3a8 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class Qwen2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def _init_weight_names(self): super()._init_weight_names() diff --git a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py index 7cf6366223..3c974691d5 100644 --- a/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen2_reward/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,8 @@ class Qwen2RewardPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) del self.lm_head_weight_ self.score_up_weight_ = ROWMMWeight( weight_names="score.0.weight", diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 838590325c..747be932d9 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -4,13 +4,10 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.get_mrope_position_ids import get_mrope_position_triton -from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo from lightllm.utils.envs_utils import get_env_start_args class Qwen2VLInferStateInfo(LlamaInferStateInfo): - init_flash_attention_state_func = FlashAttentionStateInfo._init_flash_attention_state - def __init__(self): super().__init__() self.position_cos = None @@ -35,10 +32,6 @@ def init_some_extra_state(self, model): 
self.position_ids = self.position_ids.contiguous() self.position_cos = model._cos_cached[self.position_ids] self.position_sin = model._sin_cached[self.position_ids] - if get_env_start_args().enable_fa3: - self.max_seq_len = self.max_kv_seq_len - self.q_max_seq_len = self.max_q_seq_len - self.init_flash_attention_state_func(model) return def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: @@ -85,6 +78,6 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: position_ids=position_ids, b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, - b_start_loc=self.b_start_loc, + b_start_loc=self.b_q_start_loc, ) return position_ids diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index 19e17c36e8..298a77044c 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) mrope_section = network_config["rope_scaling"]["mrope_section"] self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 61dd06773f..dd4181fbfb 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -95,9 +95,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 20f135e761..5f0c91287d 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -13,8 +13,8 @@ class Qwen3TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] return diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 86b9e172a9..90b7810adf 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -5,8 +5,8 @@ class Qwen3TransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3/model.py b/lightllm/models/qwen3/model.py index 21e71e0e02..e48b36e0f7 100644 --- a/lightllm/models/qwen3/model.py +++ b/lightllm/models/qwen3/model.py @@ -1,5 +1,3 @@ -import torch -from typing import final from 
lightllm.models.registry import ModelRegistry from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 10a734e5c3..c85c423c29 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -19,7 +19,7 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 @@ -28,7 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]): ) self.num_experts_per_tok = network_config["num_experts_per_tok"] self.norm_topk_prob = network_config["norm_topk_prob"] - super().__init__(layer_num, network_config, mode) + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) self.tp_v_head_num_ = max(self.tp_v_head_num_, 1) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 72721f9d6f..486f4d6966 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -4,14 +4,14 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): self.n_routed_experts = network_config["num_experts"] self.is_moe = ( network_config["num_experts"] > 0 and layer_num not in network_config["mlp_only_layers"] and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 ) - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py index d219173401..4e2b65d743 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe_mtp/layer_infer/transformer_layer_infer.py @@ -14,8 +14,8 @@ class Qwen3MOEMTPTransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) return def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 6cc447a594..8ba95c1386 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -9,8 +9,8 @@ class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + 
def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.proj.weight", diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 22d4d19505..095afecd91 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight(self): diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index 96e453ebe7..c24166e13d 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -8,8 +8,8 @@ class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward( diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 17ce4b7693..d1c51365a1 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -1,20 +1,12 @@ import torch -import torch.functional as F import torch.distributed as dist -import numpy as np -from functools import partial from typing import Tuple from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused -from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward @@ -22,8 +14,8 @@ class Qwen3VLTransformerLayerInfer(Qwen2VLTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, 
device="cuda" diff --git a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py index 5d41d85515..8a380853de 100644 --- a/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py @@ -12,8 +12,8 @@ def rename_weight_keys(weights): class Qwen3VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/qwen3_vl/model.py b/lightllm/models/qwen3_vl/model.py index 0d8a81f671..74aa33e3c0 100644 --- a/lightllm/models/qwen3_vl/model.py +++ b/lightllm/models/qwen3_vl/model.py @@ -37,9 +37,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index b155f8b907..328cc0a625 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -7,13 +7,12 @@ from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce -from lightllm.utils.dist_utils import get_global_world_size from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.mrope_section = torch.tensor( network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index b1f5ee6600..52a982f495 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py index f4eef6e698..48ddf52089 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -4,8 +4,8 @@ class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - 
super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) def load_hf_weights(self, weights): moe_prefix = f"model.layers.{self.layer_num_}.mlp.experts" diff --git a/lightllm/models/qwen3_vl_moe/model.py b/lightllm/models/qwen3_vl_moe/model.py index b11f22fdb7..cc1201de2c 100644 --- a/lightllm/models/qwen3_vl_moe/model.py +++ b/lightllm/models/qwen3_vl_moe/model.py @@ -25,9 +25,6 @@ def __init__(self, kvargs): super().__init__(kvargs) return - def _init_inferstate_cls(self): - pass - def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index f439073077..939843a3eb 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -24,8 +24,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index edebccf17f..d942d68497 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -1,5 +1,3 @@ -import json -import numpy as np import unicodedata from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.server.core.objs import SamplingParams diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 395ed4ba1a..f908dbdd3b 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -8,8 +8,8 @@ class StablelmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.partial_rotary_factor = self.network_config_.get("partial_rotary_factor", 1) return diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 0ad3e07df5..3d044eeb56 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -2,8 +2,8 @@ class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.final_norm_weight_ = NoTpNormWeight( weight_name="model.norm.weight", data_type=self.data_type_, diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index a1a73f6745..03ee50feb5 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -2,8 +2,8 @@ class StablelmTransformerLayerWeight(Qwen2TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _init_weight_names(self): diff --git a/lightllm/models/stablelm/model.py b/lightllm/models/stablelm/model.py index 2ed710fd4c..a3d295358f 100644 --- a/lightllm/models/stablelm/model.py +++ b/lightllm/models/stablelm/model.py @@ -1,6 +1,3 @@ -import os -import json -import torch from lightllm.models.registry import ModelRegistry from lightllm.models.stablelm.layer_infer.transformer_layer_infer import StablelmTransformerLayerInfer from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index 52072a3487..6b88c066ee 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -9,8 +9,8 @@ class StarcoderPreLayerInfer(PreLayerInfer): """ """ - def __init__(self, network_config, mode): - super().__init__(network_config, mode) + def __init__(self, network_config): + super().__init__(network_config) self.layer_norm_eps_ = network_config["layer_norm_epsilon"] def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): diff --git a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py index 018816fcc6..074f3411a7 100644 --- a/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/transformer_layer_infer.py @@ -1,17 +1,19 @@ from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from functools import partial class StarcoderTransformerLayerInfer(BloomTransformerLayerInfer): """ """ - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) self.tp_k_head_num_ = 1 self.tp_v_head_num_ = 1 self._bind_func() return def _bind_func(self): - LlamaTransformerLayerInfer._bind_attention(self) + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_attention_kernel, self) return diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index d5bdd79a7b..329a0245f0 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -8,8 +8,8 @@ class StarcoderPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( 
weight_name="transformer.wte.weight", diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 2aa9dd9ef2..41f24f79cb 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class StarcoderTransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg, layer_prefix="transformer.h") + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg, layer_prefix="transformer.h") assert network_config["num_attention_heads"] % self.tp_world_size_ == 0 def load_hf_weights(self, weights): diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 796a96bc4a..09e3299eb6 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -5,8 +5,8 @@ class Starcoder2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, network_config, mode=[]): - super().__init__(layer_num, network_config, mode) + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 28a26cb4b3..6ee1885372 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -4,8 +4,8 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py index 6314fa0e57..53342e221f 100644 --- a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py @@ -3,8 +3,8 @@ class Starcoder2TransformerLayerWeight(LlamaTransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _parse_config(self): diff --git a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py index b24fc0f0d1..44e18c2826 100644 --- a/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py @@ -17,8 +17,8 @@ def rename_weight_keys(weights): class 
Tarsier2Qwen2PreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): @@ -28,8 +28,8 @@ def load_hf_weights(self, weights): class Tarsier2LlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) return def load_hf_weights(self, weights): diff --git a/lightllm/models/vit/layer_infer/post_layer_infer.py b/lightllm/models/vit/layer_infer/post_layer_infer.py index 613aec3fa7..fa4a87f158 100644 --- a/lightllm/models/vit/layer_infer/post_layer_infer.py +++ b/lightllm/models/vit/layer_infer/post_layer_infer.py @@ -9,11 +9,10 @@ class ViTPostLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode self.llm_hidden_size = network_config["llm_hidden_size"] self.downsample_ratio = network_config["downsample_ratio"] return diff --git a/lightllm/models/vit/layer_infer/pre_layer_infer.py b/lightllm/models/vit/layer_infer/pre_layer_infer.py index 896e8e898c..306bf9f0e6 100644 --- a/lightllm/models/vit/layer_infer/pre_layer_infer.py +++ b/lightllm/models/vit/layer_infer/pre_layer_infer.py @@ -11,11 +11,10 @@ class ViTPreLayerInfer: """ """ - def __init__(self, network_config, mode): + def __init__(self, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.network_config_ = network_config - self.mode = mode return def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight): diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0b89dca11e..0d55d1b57f 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -13,7 +13,7 @@ class ViTTransformerLayerInfer: """ """ - def __init__(self, layer_num, network_config, mode=[]): + def __init__(self, layer_num, network_config): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.eps_ = network_config["layer_norm_eps"] @@ -25,7 +25,6 @@ def __init__(self, layer_num, network_config, mode=[]): self.tp_padding_embed_dim_ = self.tp_padding_head_num * self.head_dim_ self.network_config_ = network_config - self.mode = mode self.layer_num_ = layer_num return diff --git a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py index 276d4e5d0b..e2bed10361 100644 --- a/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py @@ -7,8 +7,8 @@ class ViTPreAndPostLayerWeight(PreAndPostLayerWeight): - def __init__(self, data_type, network_config, mode): - super().__init__(data_type, network_config, mode) + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) self.embed_dim = self.network_config_["hidden_size"] self.image_size = self.network_config_["image_size"] self.patch_size = self.network_config_["patch_size"] diff --git 
a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index c6024594e3..dffcc16fe8 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -14,8 +14,8 @@ class ViTTransformerLayerWeight(TransformerLayerWeight): - def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): - super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) return def _cuda(self, cpu_tensor): diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index b8e6eaf929..9c2bc42426 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -40,7 +40,6 @@ def __init__(self, kvargs): self.tp_world_size_ = get_dp_world_size() self.weight_dir_ = kvargs["weight_dir"] self.load_way = kvargs.get("load_way", "HF") - self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])] self.weight_dict = kvargs.get("weight_dict", None) self.data_type = kvargs.get("data_type", "float16") self.quant_type = kvargs.get("quant_type", None) @@ -112,15 +111,12 @@ def _padding_hidden_size(self): return def _init_weights(self): - self.pre_post_weight = self.pre_and_post_weight_class( - self.data_type, network_config=self.config, mode=self.mode - ) + self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( i, self.data_type, network_config=self.config, - mode=self.mode, quant_cfg=self.quant_cfg, ) for i in range(self.config["num_hidden_layers"]) @@ -141,10 +137,10 @@ def _init_quant(self): logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") def _init_infer_layer(self): - self.pre_infer = self.pre_layer_infer_class(network_config=self.config, mode=self.mode) - self.post_infer = self.post_layer_infer_class(network_config=self.config, mode=self.mode) + self.pre_infer = self.pre_layer_infer_class(network_config=self.config) + self.post_infer = self.post_layer_infer_class(network_config=self.config) self.layers_infer = [ - self.transformer_layer_infer_class(i, network_config=self.config, mode=self.mode) + self.transformer_layer_infer_class(i, network_config=self.config) for i in range(self.config["num_hidden_layers"]) ] return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d193bab41c..16c844d483 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -219,29 +219,6 @@ def make_argument_parser() -> argparse.ArgumentParser: the --nccl_host must equal to the config_server_host, and the --nccl_port must be unique for a config_server, dont use same nccl_port for different inference node, it will be critical error""", ) - - parser.add_argument( - "--mode", - type=str, - default=[], - nargs="+", - help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_int8kv_flashdecoding | ppl_int8kv_flashdecoding_diverse - | ppl_fp16 | triton_flashdecoding - | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv - | export_fp8kv_calibration - triton_flashdecoding mode is for long context, current support llama llama2 qwen; - triton_gqa_attention and triton_gqa_flashdecoding is fast kernel for model which use GQA; - triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel; - triton_fp8kv mode use float8 to store kv cache, currently only for deepseek2; - offline_calibration_fp8kv mode use float8 to store kv cache, need fa3 or flashinfer backend, - currently only for llama and qwen model; - export_fp8kv_calibration record and export kv cache quant calibration results to a json file. - It can be used for llama and qwen model. - Calibration need to disable cudagraph and use fa3 or flashinfer backend. - ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel; - ppl_fp16 mode use ppl fast fp16 decode attention kernel; - you need to read source code to make sure the supported detail mode for all models""", - ) parser.add_argument( "--trust_remote_code", action="store_true", @@ -337,21 +314,40 @@ def make_argument_parser() -> argparse.ArgumentParser: only deepseekv3 model supported now.""", ) parser.add_argument( - "--enable_flashinfer_prefill", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for prefill, - only deepseekv3 model supported now.""", + "--llm_prefill_att_backend", + type=str, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["triton"], + help="""prefill attention kernel used in llm. + None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( - "--enable_flashinfer_decode", - action="store_true", - help="""inference backend will use the attention kernel of flashinfer for decode, - only deepseekv3 model supported now.""", + "--llm_decode_att_backend", + type=str, + nargs="+", + choices=["None", "triton", "fa3", "flashinfer"], + default=["triton"], + help="""decode attention kernel used in llm. 
+ None: automatically select backend based on current GPU device, + not supported yet, will support in future""", ) parser.add_argument( - "--enable_fa3", - action="store_true", - help="""inference backend will use the fa3 attention kernel for prefill and decode""", + "--llm_kv_type", + type=str, + choices=["None", "int8kv", "int4kv"], + default="None", + help="""kv type used in llm, None for dtype that llm used in config.json. + fp8kv: not fully supported yet, will support in future""", + ) + parser.add_argument( + "--llm_kv_quant_group_size", + type=int, + default=8, + help="""kv quant group size used in llm kv, when llm_kv_type is quanted type,such as int8kv, + this params will be effective. + """, ) parser.add_argument( "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources" diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 4ead3cbbf7..3ae3789f47 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -122,21 +122,6 @@ def normal_or_p_d_start(args): if args.return_all_prompt_logprobs: assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache" assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill" - if "offline_calibration_fp8kv" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "offline_calibration_fp8kv mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - if "export_fp8kv_calibration" in args.mode: - assert args.enable_fa3 is True or ( - args.enable_flashinfer_prefill is True and args.enable_flashinfer_decode is True - ), ( - "export_fp8kv_calibration mode need enable fa3 or flashinfer, add --enable_fa3 or " - "--enable_flashinfer_prefill and --enable_flashinfer_decode" - ) - assert args.disable_cudagraph is True, "export_fp8kv_calibration mode need disable cudagraph" # 部分模式还不能支持与高级动态调度算法协同,to do. 
if args.diverse_mode: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5ebadaf165..7b440718f5 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -63,7 +63,6 @@ class StartArgs: nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) @@ -116,8 +115,14 @@ class StartArgs: quant_cfg: Optional[str] = field(default=None) vit_quant_type: Optional[str] = field(default=None) vit_quant_cfg: Optional[str] = field(default=None) - enable_flashinfer_prefill: bool = field(default=False) - enable_flashinfer_decode: bool = field(default=False) + llm_prefill_att_backend: List[str] = field( + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_decode_att_backend: List[str] = field( + default=("None",), metadata={"choices": ["None", "triton", "fa3", "flashinfer"]} + ) + llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) + llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( default="gpu_counter", metadata={"choices": ["cpu_counter", "pin_mem_counter", "gpu_counter"]} @@ -153,6 +158,3 @@ class StartArgs: # multi_modal enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) - - # kernel setting - enable_fa3: bool = field(default=False) diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 89c46d9ed9..ac5c1abee3 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -58,7 +58,6 @@ def __init__(self, args: StartArgs): # 判断是否是保守调度,保守调度不会发生暂停 req 的情况,但是有些场景可能影响吞吐 self.is_safe_schedule = args.router_token_ratio == 0.0 self.load_way = args.load_way - self.mode = args.mode self.max_total_token_num = args.max_total_token_num self.shm_req_manager = ShmReqManager() # 用共享内存进行共享,router 模块读取进行精确的调度估计 @@ -155,7 +154,6 @@ async def wait_to_model_ready(self): "weight_dir": self.model_weightdir, "load_way": self.load_way, "max_total_token_num": self.max_total_token_num, - "mode": self.mode, "max_req_num": self.args.running_max_req_size + 8, "max_seq_length": self.args.max_req_total_len + 8, # 留一点余量 "nccl_host": self.args.nccl_host, diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 92653bc0cd..805c9b8e50 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -88,7 +88,6 @@ def init_model(self, kvargs): # dp_size_in_node 计算兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.dp_size_in_node = max(1, self.dp_size // self.nnodes) self.load_way = kvargs["load_way"] - self.mode = kvargs["mode"] self.disable_chunked_prefill = self.args.disable_chunked_prefill self.chunked_prefill_size = self.args.chunked_prefill_size self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs @@ -148,7 +147,6 @@ def init_model(self, kvargs): "weight_dir": self.weight_dir, "max_total_token_num": max_total_token_num, "load_way": self.load_way, - 
"mode": self.mode, "max_req_num": kvargs.get("max_req_num", 1000), "max_seq_length": kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": kvargs.get("is_token_healing", False), @@ -302,7 +300,6 @@ def init_mtp_draft_model(self, main_kvargs: dict): "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, "load_way": main_kvargs["load_way"], - "mode": main_kvargs["mode"], "max_req_num": main_kvargs.get("max_req_num", 1000), "max_seq_length": main_kvargs.get("max_seq_length", 1024 * 5), "is_token_healing": False, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index f3450261b7..a8a5224ebc 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -399,7 +399,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output) draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs] draft_model_input.mem_indexes = torch.cat( [draft_model_input.mem_indexes.view(-1, self.mtp_step + 1)[:, 1:], eagle_mem_indexes_i.view(-1, 1)], diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index df10a6d4e6..bb0e848e76 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -591,7 +591,7 @@ def _draft_decode_eagle( draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input) # update the meta info of the inference draft_model_input.b_seq_len += 1 - draft_model_input.max_len_in_batch += 1 + draft_model_input.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes[_step * real_req_num : (_step + 1) * real_req_num] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -955,7 +955,7 @@ def _draft_decode_eagle_overlap( ) draft_model_input0.b_seq_len += 1 - draft_model_input0.max_len_in_batch += 1 + draft_model_input0.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes0[_step * real_req_num0 : (_step + 1) * real_req_num0] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, @@ -969,7 +969,7 @@ def _draft_decode_eagle_overlap( ).view(-1) draft_model_input1.b_seq_len += 1 - draft_model_input1.max_len_in_batch += 1 + draft_model_input1.max_kv_seq_len += 1 eagle_mem_indexes_i = eagle_mem_indexes1[_step * real_req_num1 : (_step + 1) * real_req_num1] eagle_mem_indexes_i = F.pad( input=eagle_mem_indexes_i, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 6465995c45..03ac4cfb05 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -73,7 +73,6 @@ def padded_prepare_prefill_inputs( max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) input_ids = 
torch.tensor(input_ids, dtype=torch.int64, device="cpu") @@ -102,7 +101,6 @@ def padded_prepare_prefill_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -150,6 +148,7 @@ def padded_prepare_decode_inputs( seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1 b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len b_mtp_index.append(0) batch_multimodal_params.append(req.multimodal_params) @@ -160,32 +159,30 @@ def padded_prepare_decode_inputs( total_token_num += seq_len b_req_idx.append(req.req_idx) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(step + 1) batch_multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) - # padding fake req for decode for _ in range(padded_req_num): seq_len = 2 total_token_num += seq_len b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_mtp_index.append(0) batch_multimodal_params.append({"images": [], "audios": []}) for step in range(args_mtp_step): seq_len += 1 total_token_num += seq_len b_seq_len.append(seq_len) + b_q_seq_len.append(1) b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID) b_mtp_index.append(step + 1) batch_multimodal_params.append({"images": [], "audios": []}) - b_q_seq_len.append(1 + args_mtp_step) - max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) - max_len_in_batch = max(b_seq_len) b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") @@ -210,7 +207,6 @@ def padded_prepare_decode_inputs( model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index bdb36054b4..4eb8c7e1e6 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -50,7 +50,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> max_kv_seq_len = max(b_seq_len) max_cache_len = max(b_ready_cache_len) - max_len_in_batch = max(b_q_seq_len) max_q_seq_len = max(b_q_seq_len) input_ids = np.concatenate(input_ids, dtype=np.int64) @@ -72,7 +71,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, max_cache_len=max_cache_len, @@ -95,7 +93,6 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]: run_reqs: List[InferReq] = [] total_token_num = 0 - max_len_in_batch = 0 b_req_idx = [] b_mtp_index = [] b_seq_len = [] @@ -107,8 +104,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len = req.get_cur_total_len() assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}" b_seq_len.append(seq_len) + b_q_seq_len.append(1) total_token_num += seq_len - 
max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) # process the draft tokens. @@ -118,10 +115,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In seq_len += 1 b_seq_len.append(seq_len) total_token_num += seq_len - max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) - b_q_seq_len.append(req.mtp_step + 1) + b_q_seq_len.append(1) max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) @@ -146,7 +142,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In model_input = ModelInput( batch_size=b_seq_len.shape[0], total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_q_seq_len, max_kv_seq_len=max_kv_seq_len, input_ids=None, diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b5822a342c..0a70f1dfa6 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -26,10 +26,16 @@ def get_unique_server_name(): def set_cuda_arch(args): if not torch.cuda.is_available(): return - if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + + from lightllm.server.core.objs.start_args_type import StartArgs + + args: StartArgs = args + + if "flashinfer" in args.llm_prefill_att_backend or "flashinfer" in args.llm_decode_att_backend: capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + return def set_env_start_args(args): @@ -209,7 +215,7 @@ def get_diverse_max_batch_shared_group_size() -> int: @lru_cache(maxsize=None) def enable_diverse_mode_gqa_decode_fast_kernel() -> bool: - return get_env_start_args().diverse_mode and "ppl_int8kv_flashdecoding_diverse" in get_env_start_args().mode + return get_env_start_args().diverse_mode and "int8kv" == get_env_start_args().llm_kv_type @lru_cache(maxsize=None) diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 4875b4eee6..3256fdd1fd 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -19,13 +19,11 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.kv_cache_mem_manager import ( MemoryManager, - INT8KVMemoryManager, CalibrationFP8KVMemoryManager, ExportCalibrationMemoryManager, PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, Deepseek2MemoryManager, - Deepseek2FP8KVMemoryManager, ) from typing import List, Tuple, Optional @@ -77,17 +75,6 @@ def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": scale_head_dim=0, scale_data_type=get_llm_data_type(), ) - elif mem_manager_class is Deepseek2FP8KVMemoryManager: - cpu_cache_meta = CpuKVCacheMeta( - page_num=0, - token_page_size=args.cpu_cache_token_page_size, - layer_num=get_layer_num(args.model_dir), - num_heads=1, - head_dim=512 + 64 + 2, - data_type=torch.uint8, - scale_head_dim=0, - scale_data_type=get_llm_data_type(), - ) elif mem_manager_class is MemoryManager: cpu_cache_meta = CpuKVCacheMeta( page_num=0, diff --git a/test/acc/test_deepseekr1.sh b/test/acc/test_deepseekr1.sh index e167303a35..180d2d4e20 100644 --- a/test/acc/test_deepseekr1.sh +++ b/test/acc/test_deepseekr1.sh @@ -1,4 +1,4 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --enable_fa3 +LOADWORKER=18 python -m lightllm.server.api_server --batch_max_tokens 6000 
--model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 diff --git a/test/acc/test_deepseekr1_mtp.sh b/test/acc/test_deepseekr1_mtp.sh index 046314a728..7eaffd4993 100644 --- a/test/acc/test_deepseekr1_mtp.sh +++ b/test/acc/test_deepseekr1_mtp.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --port 8089 --mem_fraction 0.75 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_deepseekr1_mtp_ep.sh b/test/acc/test_deepseekr1_mtp_ep.sh index 2ea5f74387..0467f76e6a 100644 --- a/test/acc/test_deepseekr1_mtp_ep.sh +++ b/test/acc/test_deepseekr1_mtp_ep.sh @@ -1,3 +1,3 @@ -LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --enable_fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 +LOADWORKER=18 MOE_MODE=EP NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 python -m lightllm.server.api_server --model_dir /mtc/models/DeepSeek-R1 --tp 8 --dp 8 --port 8089 --max_total_token_num 60000 --graph_max_batch_size 16 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --batch_max_tokens 6000 --mtp_mode eagle_with_att --mtp_draft_model_dir /mtc/models/DeepSeek-R1-NextN --mtp_step 2 HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"deepseek-ai/DeepSeek-R1", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 32 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen2.sh b/test/acc/test_qwen2.sh index 265d679e8a..bb5603b5be 100644 --- a/test/acc/test_qwen2.sh +++ b/test/acc/test_qwen2.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"Qwen/Qwen2.5-Math-7B-Instruct", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/acc/test_qwen3.sh b/test/acc/test_qwen3.sh index c0da5ec96e..36a3c96804 100644 --- a/test/acc/test_qwen3.sh +++ 
b/test/acc/test_qwen3.sh @@ -1,5 +1,5 @@ # first -LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --enable_fa3 +LOADWORKER=18 CUDA_VISIBLE_DEVICES=6,7 python -m lightllm.server.api_server --model_dir /mtc/models/qwen3-8b --tp 2 --port 8089 --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # second HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions --model_args '{"model":"qwen/qwen3-8b", "base_url":"http://localhost:8089/v1/completions", "max_length": 16384}' --tasks gsm8k --batch_size 500 --confirm_run_unsafe_code \ No newline at end of file diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 3fc7ee4b45..a8abd2ae64 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -73,7 +73,6 @@ def overlap_prefill( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: total_token_num // 2] _0_mem_indexes = mem_indexes[: total_token_num // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -83,7 +82,6 @@ def overlap_prefill( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -96,7 +94,6 @@ def overlap_prefill( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[total_token_num // 2 :] _1_mem_indexes = mem_indexes[total_token_num // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -107,7 +104,6 @@ def overlap_prefill( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -129,7 +125,6 @@ def overlap_decode( ): _0_batch_size = batch_size // 2 _0_total_token_num = total_token_num // 2 - _0_max_len_in_batch = max_len_in_batch _0_input_ids = input_ids[: batch_size // 2] _0_mem_indexes = mem_indexes[: batch_size // 2] _0_b_req_idx = b_req_idx[: batch_size // 2] @@ -138,7 +133,6 @@ def overlap_decode( micro_batch1 = ModelInput( batch_size=_0_batch_size, total_token_num=_0_total_token_num, - max_len_in_batch=_0_max_len_in_batch, input_ids=_0_input_ids, b_req_idx=_0_b_req_idx, b_mtp_index=_0_b_mtp_index, @@ -149,7 +143,6 @@ def overlap_decode( _1_batch_size = batch_size - batch_size // 2 _1_total_token_num = total_token_num - total_token_num // 2 - _1_max_len_in_batch = max_len_in_batch _1_input_ids = input_ids[batch_size // 2 :] _1_mem_indexes = mem_indexes[batch_size // 2 :] _1_b_req_idx = b_req_idx[batch_size // 2 :] @@ -159,7 +152,6 @@ def overlap_decode( micro_batch2 = ModelInput( batch_size=_1_batch_size, total_token_num=_1_total_token_num, - max_len_in_batch=_1_max_len_in_batch, input_ids=_1_input_ids, b_req_idx=_1_b_req_idx, b_mtp_index=_1_b_mtp_index, @@ -191,7 +183,6 @@ def prefill( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=max_len_in_batch, max_kv_seq_len=max_len_in_batch, max_cache_len=0, @@ -217,7 +208,6 @@ def decode( model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=max_len_in_batch, max_q_seq_len=1, 
max_kv_seq_len=max_len_in_batch, input_ids=input_ids, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index 942af0f883..07ad52a132 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -129,7 +129,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, - max_len_in_batch=input_len, input_ids=test_data, mem_indexes=mem_indexes, b_req_idx=b_req_idx, @@ -197,7 +196,6 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), total_token_num=nopad_total_token_num, - max_len_in_batch=nopad_max_len_in_batch, input_ids=decode_input_ids, mem_indexes=mem_indexes, b_req_idx=nopad_b_seq_idx, diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index f5dae19b92..e00af27139 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -108,7 +108,6 @@ sh multi_pd_master/pd_decode.sh - `--model_dir`: Model file path - `--tp`: Tensor parallelism degree - `--dp`: Data parallelism degree -- `--enable_fa3`: Enable Flash Attention 3.0 - `--nnodes`: Total number of nodes - `--node_rank`: Current node rank - `--nccl_host`: NCCL communication host address diff --git a/test/start_scripts/draft.sh b/test/start_scripts/draft.sh index 866f5f2fa5..04f573cd6d 100644 --- a/test/start_scripts/draft.sh +++ b/test/start_scripts/draft.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 @@ -16,7 +16,7 @@ HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 lm_eval --model local-completions \ LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /mtc/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh index 3a139968a6..68f80b39d5 100644 --- a/test/start_scripts/multi_node_ep_node0.sh +++ b/test/start_scripts/multi_node_ep_node0.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh index b24a598688..10aee85285 100644 --- a/test/start_scripts/multi_node_ep_node1.sh +++ b/test/start_scripts/multi_node_ep_node1.sh @@ -6,7 +6,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node0.sh b/test/start_scripts/multi_node_tp_node0.sh index 
b86bdeb358..d750da93ca 100644 --- a/test/start_scripts/multi_node_tp_node0.sh +++ b/test/start_scripts/multi_node_tp_node0.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_node_tp_node1.sh b/test/start_scripts/multi_node_tp_node1.sh index 378977ab2e..cb495496e8 100644 --- a/test/start_scripts/multi_node_tp_node1.sh +++ b/test/start_scripts/multi_node_tp_node1.sh @@ -5,7 +5,7 @@ export nccl_host=$1 LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ diff --git a/test/start_scripts/multi_pd_master/pd_decode.sh b/test/start_scripts/multi_pd_master/pd_decode.sh index 4cefef6fb2..2b7bb80d76 100644 --- a/test/start_scripts/multi_pd_master/pd_decode.sh +++ b/test/start_scripts/multi_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --nccl_port 12322 \ --tp 8 \ --dp 8 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 --config_server_host $config_server_host \ --config_server_port 60088 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh index b845da435d..eaa343ef62 100644 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --tp 8 \ --dp 8 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ --config_server_port 60088 diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh index cad172d515..7406d94628 100644 --- a/test/start_scripts/single_node_ep.sh +++ b/test/start_scripts/single_node_ep.sh @@ -3,7 +3,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_node_tp.sh b/test/start_scripts/single_node_tp.sh index 1fb461bb11..ee10b6c101 100644 --- a/test/start_scripts/single_node_tp.sh +++ b/test/start_scripts/single_node_tp.sh @@ -2,7 +2,7 @@ LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ ---enable_fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_node_tp_cpu_cache_enable.sh b/test/start_scripts/single_node_tp_cpu_cache_enable.sh index 3caabb59bd..47da83dbe9 100644 --- a/test/start_scripts/single_node_tp_cpu_cache_enable.sh +++ b/test/start_scripts/single_node_tp_cpu_cache_enable.sh @@ -3,7 +3,7 @@ LOADWORKER=18 python -m 
lightllm.server.api_server \ --model_dir /mtc/models/qwen3-8b --tp 2 --dp 1 --enable_cpu_cache --cpu_cache_storage_size 66 --cpu_cache_token_page_size 128 \ --batch_max_tokens 4096 --chunked_prefill_size 2048 \ --max_total_token_num 20000 \ ---mode "ppl_int8kv_flashdecoding" | tee log.txt +--llm_kv_type int8kv | tee log.txt # 精度评测命令 diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index ae16b96ad4..36804dd11e 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 1b43c11cc4..5fb34a973e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -18,7 +18,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8121 \ --nccl_port 12322 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index 303de29758..5a37df0b1d 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -19,7 +19,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index f6e2e4b685..b94a1f8ccd 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -13,7 +13,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --host $host \ --port 8019 \ --nccl_port 2732 \ ---enable_fa3 \ +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 diff --git a/test/test_api/test_generate_api.py b/test/test_api/test_generate_api.py index 05fbda44ea..4ea74b7f6d 100644 --- a/test/test_api/test_generate_api.py +++ b/test/test_api/test_generate_api.py @@ -19,7 +19,7 @@ def run(self): print("Error:", response.status_code, response.text) -url = "http://localhost:8000/generate" +url = "http://localhost:8089/generate" headers = {"Content-Type": "application/json"} for i in range(1): diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py new file mode 100644 index 0000000000..d1a53f873f --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad.py @@ -0,0 +1,104 @@ +import torch +import time +import pytest +from 
lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( + context_attention_fwd_contiguous_kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_context_attention_fwd2(q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len): + batch = b_start_loc.shape[0] + + for i in range(batch): + start_loc = b_start_loc[i] + kv_start_loc = b_kv_start_loc[i] + seq_len = b_seq_len[i] + prompt_cache_len = b_prompt_cache_len[i] + cur_q = q[start_loc : start_loc + seq_len - prompt_cache_len, :, :] + cur_q = cur_q.clone().to(torch.float32) + cur_k = k[kv_start_loc : (kv_start_loc + seq_len), :, :] + cur_k = cur_k.clone().to(torch.float32) + + cur_v = v[kv_start_loc : (kv_start_loc + seq_len), :, :] + cur_v = cur_v.clone().to(torch.float32) + + dk = cur_q.shape[-1] + cur_q = cur_q.permute(1, 0, 2) + cur_k = cur_k.permute(1, 2, 0) + cur_v = cur_v.permute(1, 0, 2) + dk = cur_q.shape[-1] + + p = torch.matmul(cur_q, cur_k) / torch.sqrt(torch.tensor(dk, dtype=torch.float32)) + + q_index = (torch.arange(cur_q.shape[1]).to(p.device) + prompt_cache_len).view(-1, 1) + k_index = torch.arange(seq_len).to(p.device).view(1, -1) + + p[:, (q_index < k_index)] = float("-inf") + + s = torch.nn.functional.softmax(p, dim=-1) + + o[start_loc : start_loc + seq_len - prompt_cache_len, :, :] = torch.matmul(s, cur_v).transpose(0, 1) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, prompt_cache_len", + [ + (b, H, N_CTX, D_HEAD, prompt_cache_len) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for prompt_cache_len in [0, 56, 200] + ], +) +def test_context_attention_fwd_contiguous_kv(B, H, N_CTX, D_HEAD, prompt_cache_len): + dtype = torch.float16 + prompt_cache_len = 0 + if prompt_cache_len >= N_CTX - 1: + return + + q = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + k = kv[:, :H, :] + v = kv[:, H:, :] + + o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + torch_o = torch.empty((B * (N_CTX - prompt_cache_len), H, D_HEAD), dtype=dtype, device="cuda") + + max_q_input_len = N_CTX - prompt_cache_len + + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_prompt_cache_len = torch.zeros(B, dtype=torch.int32, device="cuda") + + for i in range(B): + b_seq_len[i] = N_CTX + if i != 0: + b_start_loc[i] = b_start_loc[i - 1] + N_CTX - prompt_cache_len + b_prompt_cache_len[i] = prompt_cache_len + + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + torch_context_attention_fwd2(q, k, v, torch_o, b_start_loc, b_kv_start_loc, b_seq_len, b_prompt_cache_len) + context_attention_fwd_contiguous_kv( + q=q, + k=k, + v=v, + o=o, + b_start_loc=b_start_loc, + b_kv_start_loc=b_kv_start_loc, + b_seq_len=b_seq_len, + max_q_input_len=max_q_input_len, + b_prompt_cache_len=b_prompt_cache_len, + ) + + assert torch.allclose(torch_o, o, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(o.flatten().float(), torch_o.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py 
b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py new file mode 100644 index 0000000000..83537ec708 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int4kv_copy_kv.py @@ -0,0 +1,62 @@ +import torch +import pytest +import numpy as np +from typing import Tuple +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv, dequantize_int4kv +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def test_quanted_and_dequant(): + """Test quantization followed by dequantization.""" + batch_size = 1 + seq_len = 8 + head_num = 4 + k_head_num = 2 + v_head_num = 2 + assert k_head_num + v_head_num == head_num + head_dim = 64 + quant_group_size = 8 + + # Create original data + original_kv = torch.randn(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).clamp_(-1, 1).cuda() + dest_loc = torch.arange(batch_size * seq_len, dtype=torch.int64).cuda() + + # Quantize + group_count = head_dim // quant_group_size + kv_buffer = torch.zeros(batch_size * seq_len, head_num, head_dim // 2, dtype=torch.int8).cuda() + kv_scale_buffer = torch.zeros(batch_size * seq_len, head_num, group_count, dtype=torch.float32).cuda() + destindex_copy_int4kv(original_kv, dest_loc, kv_buffer, kv_scale_buffer, quant_group_size) + + # Dequantize + req_to_token_indexs = torch.arange(seq_len, dtype=torch.int64).view(1, -1).cuda() + b_seq_len = torch.tensor([seq_len], dtype=torch.int32).cuda() + b_req_idx = torch.tensor([0], dtype=torch.int32).cuda() + b_kv_start_loc = torch.tensor([0], dtype=torch.int32).cuda() + + recovered_kv = torch.zeros(batch_size * seq_len, head_num, head_dim, dtype=torch.float32).cuda() + + dequantize_int4kv( + k=kv_buffer[:, 0:k_head_num, :], + k_scale=kv_scale_buffer[:, 0:k_head_num, :], + v=kv_buffer[:, k_head_num:, :], + v_scale=kv_scale_buffer[:, k_head_num:, :], + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=recovered_kv[:, :k_head_num, :], + v_out=recovered_kv[:, k_head_num:, :], + max_len_in_batch=seq_len, + quant_group_size=quant_group_size, + ) + + logger.info("Round-trip test completed!") + assert torch.allclose(recovered_kv, original_kv, atol=2 / 14, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(recovered_kv.flatten().float(), original_kv.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py new file mode 100644 index 0000000000..149c9894ae --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_ppl_int8kv_copy_kv.py @@ -0,0 +1,86 @@ +import torch +import time +import pytest +from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import ( + dequantize_int8kv, + destindex_copy_quantize_kv, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def torch_dequant(kv, kv_scale, b_req_idx, b_seq_len, req_to_token_indexs, odtype, group_quant_size): + batch = b_req_idx.shape[0] + tmp_out = [] + for i in range(batch): + req_idx = b_req_idx[i] + seq_len = b_seq_len[i] + kv_loc = req_to_token_indexs[req_idx, :seq_len] + head_num = kv.shape[1] + cur_kv = kv[kv_loc, :, :].reshape(seq_len, head_num, -1, group_quant_size).to(odtype) + cur_scale = kv_scale[kv_loc, :, :].reshape(seq_len, head_num, -1, 1) + out = cur_kv * 
cur_scale + tmp_out.append(out.reshape(seq_len, head_num, -1)) + return torch.cat(tmp_out, dim=0) + + +@pytest.mark.parametrize( + "B, H, N_CTX, D_HEAD, group_quant_size", + [ + (b, H, N_CTX, D_HEAD, group_quant_size) + for b in [1, 2, 4] + for H in [1, 8] + for N_CTX in [3, 10, 1024] + for D_HEAD in [64, 128] + for group_quant_size in [8, 16] + ], +) +def test_dequantize_int8kv(B, H, N_CTX, D_HEAD, group_quant_size): + dtype = torch.bfloat16 + kv = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=torch.int8, device="cuda").random_(-10, 10) + kv_scale = torch.randn((B * N_CTX, 2 * H, D_HEAD // group_quant_size), dtype=dtype, device="cuda") + out = torch.empty((B * N_CTX, 2 * H, D_HEAD), dtype=dtype, device="cuda") + req_to_token_indexs = torch.empty((B, N_CTX), dtype=torch.int32, device="cuda") + max_input_len = N_CTX + b_seq_len = torch.ones((B,), dtype=torch.int32, device="cuda") + b_seq_len.fill_(N_CTX) + b_req_idx = torch.arange(0, B, dtype=torch.int32, device="cuda") + req_to_token_indexs.view(-1)[:] = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") + b_kv_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len + + k = kv[:, :H, :] + v = kv[:, H:, :] + k_scale = kv_scale[:, :H, :] + v_scale = kv_scale[:, H:, :] + + ground_out = torch_dequant( + kv=kv, + kv_scale=kv_scale, + b_req_idx=b_req_idx, + b_seq_len=b_seq_len, + req_to_token_indexs=req_to_token_indexs, + odtype=out.dtype, + group_quant_size=group_quant_size, + ) + dequantize_int8kv( + k=k, + k_scale=k_scale, + v=v, + v_scale=v_scale, + req_to_token_indexs=req_to_token_indexs, + b_seq_len=b_seq_len, + b_req_idx=b_req_idx, + b_kv_start_loc=b_kv_start_loc, + k_out=out[:, :H, :], + v_out=out[:, H:, :], + max_len_in_batch=max_input_len, + quant_group_size=group_quant_size, + ) + assert torch.allclose(out, ground_out, atol=1e-2, rtol=0) + cos = torch.nn.CosineSimilarity(0) + assert cos(out.flatten().float(), ground_out.flatten().float()) > 0.99 + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py index 5c3ca89c65..41bc217b94 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_decode_params.py @@ -1,21 +1,9 @@ import torch import pytest -import easydict from lightllm.common.basemodel.triton_kernel.gen_decode_params import gen_decode_params -from lightllm.utils.envs_utils import set_env_start_args def test_gen_decode_params_basic(): - set_env_start_args( - easydict.EasyDict( - { - "mtp_step": 0, - "enable_flashinfer_prefill": False, - "enable_flashinfer_decode": False, - } - ) - ) - b_seq_len = torch.ones((9,), dtype=torch.int64, device="cuda") * 8192 ( b_q_seq_len, diff --git a/unit_tests/server/core/objs/test_shm_req_manager.py b/unit_tests/server/core/objs/test_shm_req_manager.py index 1d1ae2ef1a..dea40a4859 100644 --- a/unit_tests/server/core/objs/test_shm_req_manager.py +++ b/unit_tests/server/core/objs/test_shm_req_manager.py @@ -14,8 +14,6 @@ def setup_env(): running_max_req_size=10, disable_chunked_prefill=True, token_healing_mode=False, - enable_flashinfer_prefill=False, - enable_flashinfer_decode=False, ) ) # clear the lru_cache if used From 5d2f630d093911c0781edffdd6d6800941467f83 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Sat, 10 Jan 2026 13:14:32 +0800 Subject: [PATCH 05/43] fix unit test (#1173) --- 
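As a quick sanity check after applying this patch, the relocated kernel tests can be run directly with pytest. This is only a sketch: the target paths are taken from the rename list below, and a CUDA-capable GPU with Triton installed is assumed (tests that need optional kernels skip themselves).

```bash
# Run the relocated int8kv flash-decoding tests and the moved MLA kv-copy test in place.
# Paths follow the rename list in this patch; a CUDA GPU with Triton installed is assumed.
pytest -q unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/
pytest -q unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py
```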
.../test_ppl_int8kv_flash_decoding_diverse.py | 20 +- ...pl_int8kv_flash_decoding_diverse_stage1.py | 6 +- ...pl_int8kv_flash_decoding_diverse_stage2.py | 5 +- ...pl_int8kv_flash_decoding_diverse_stage3.py | 9 +- .../test_context_flashattention_nopad1.py} | 27 +-- .../kv_copy/test_mla_destindex_copy_kv.py} | 2 +- .../decode_att}/test_gqa_flash_decoding.py | 11 +- .../triton_kernel/test_atomic_event.py | 8 +- .../triton_kernel/test_gen_sampling_params.py | 1 + .../triton_kernel}/test_repack_kv_index.py | 2 +- unit_tests/common/fused_moe/test_deepep.py | 9 +- .../test_moe_silu_and_mul_mix_quant_ep.py | 14 +- .../common/fused_moe/test_softmax_topk.py | 5 +- .../test_fp8_scaled_mm_per_token.py | 12 ++ .../deepseek2/test_destindex_copy_kv_fp8.py | 48 ----- .../deepseek2/test_gqa_flash_decoding_fp8.py | 79 -------- ...st_context_flashattention_nopad_fa3_fp8.py | 154 --------------- ...ext_flashattention_nopad_flashinfer_fp8.py | 145 -------------- .../llama/test_token_attention_nopad.py | 151 -------------- .../test_token_attention_nopad_fa3_fp8.py | 187 ------------------ ...st_token_attention_nopad_flashinfer_fp8.py | 170 ---------------- .../models/qwen3-vl/test_deepstack_emb.py | 2 +- unit_tests/server/core/objs/test_req.py | 18 +- .../server/core/objs/test_shm_req_manager.py | 5 + 24 files changed, 109 insertions(+), 981 deletions(-) rename unit_tests/{models/llama => common/basemodel/triton_kernel/att/decode_att/int8kv}/test_ppl_int8kv_flash_decoding_diverse.py (91%) rename unit_tests/{models/llama => common/basemodel/triton_kernel/att/decode_att/int8kv}/test_ppl_int8kv_flash_decoding_diverse_stage1.py (93%) rename unit_tests/{models/llama => common/basemodel/triton_kernel/att/decode_att/int8kv}/test_ppl_int8kv_flash_decoding_diverse_stage2.py (96%) rename unit_tests/{models/llama => common/basemodel/triton_kernel/att/decode_att/int8kv}/test_ppl_int8kv_flash_decoding_diverse_stage3.py (79%) rename unit_tests/{models/llama/test_context_flashattention_nopad.py => common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py} (92%) rename unit_tests/{models/deepseek2/test_destindex_copy_kv.py => common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py} (93%) rename unit_tests/{models/deepseek2 => common/basemodel/triton_kernel/mla_att/decode_att}/test_gqa_flash_decoding.py (92%) rename unit_tests/{models/deepseek2 => common/basemodel/triton_kernel}/test_repack_kv_index.py (96%) delete mode 100644 unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py delete mode 100644 unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py delete mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py delete mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py delete mode 100644 unit_tests/models/llama/test_token_attention_nopad.py delete mode 100644 unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py delete mode 100644 unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py similarity index 91% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py index 67463fa7ad..ac18ffb955 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py +++ 
b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py @@ -1,4 +1,7 @@ import pytest + +pytest.skip(reason="need install lightllmKernel", allow_module_level=True) + import torch from lightllm.utils.light_utils import light_ops @@ -21,7 +24,7 @@ class MockInferState: def __init__( self, batch_size, - max_len_in_batch, + max_kv_seq_len, req_to_tokens, b_req_idx, b_seq_len, @@ -29,7 +32,7 @@ def __init__( b_mark_shared_group=None, ): self.batch_size = batch_size - self.max_len_in_batch = max_len_in_batch + self.max_kv_seq_len = max_kv_seq_len self.req_manager = MockReqManager(req_to_tokens) self.b_req_idx = b_req_idx self.b_seq_len = b_seq_len @@ -44,10 +47,11 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le 测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding 与 ppl_int8kv_flash_decoding (baseline) 的对比。 """ - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import ( + + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) - from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding as baseline_attention, ) @@ -87,7 +91,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 baseline 的 infer_state (不需要 b_shared_seq_len) baseline_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=seq_len, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -96,7 +100,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 diverse 的 infer_state diverse_infer_state = MockInferState( batch_size=batch_size, - max_len_in_batch=seq_len, + max_kv_seq_len=seq_len, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -108,8 +112,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le baseline_out = baseline_attention( q=q.clone(), infer_state=baseline_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, @@ -120,8 +122,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le diverse_out = diverse_attention( q=q.clone(), infer_state=diverse_infer_state, - q_head_num=num_heads, - head_dim=head_dim, cache_k=cache_k, cache_k_scale=cache_k_scale, cache_v=cache_v, diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py similarity index 93% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py index 30e83b88b6..5ef36e38e2 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py @@ -1,6 +1,8 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from 
lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage1 import ( + flash_decode_stage1, +) @pytest.fixture @@ -81,7 +83,7 @@ def test_flash_decode_stage1_execution(setup_tensors): new_k = k.to(q.dtype) new_v = v.to(q.dtype) - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( flash_decode_stage1 as gqa_flash_decode_stage1, ) diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py similarity index 96% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py index 2ba085cc91..cde7734817 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage2.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py @@ -1,4 +1,7 @@ import pytest + +pytest.skip(reason="need install lightllmkernel", allow_module_level=True) + import torch from lightllm.utils.light_utils import light_ops @@ -94,7 +97,7 @@ def test_flash_decode_stage2_execution(shared_seq_len): b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] - from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( flash_decode_stage1 as gqa_flash_decode_stage1, ) diff --git a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py similarity index 79% rename from unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py index b406e2dcf5..18550982b9 100644 --- a/unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse_stage3.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py @@ -1,6 +1,8 @@ import pytest import torch -from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage3 import ( + flash_diverse_decode_stage3, +) @pytest.mark.parametrize( @@ -23,7 +25,10 @@ def test_flash_diverse_decode_stage3(batch, head_num, seq_len, shared_seq_len, b flash_diverse_decode_stage3(mid_out, mid_out_logexpsum, B_Seqlen, b_shared_seq_len, out, block_seq) true_out = torch.zeros_like(out) - from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2 + + from lightllm.common.basemodel.triton_kernel.att.decode_att.mha.flash_decoding.flash_decoding_stage2 import ( + flash_decode_stage2, + ) flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, true_out, block_seq) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py 
b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py similarity index 92% rename from unit_tests/models/llama/test_context_flashattention_nopad.py rename to unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py index f24ab619bd..541594306d 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_flashattention_nopad1.py @@ -5,12 +5,11 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( +from lightllm.common.basemodel.triton_kernel.att.prefill_att.context_flashattention_nopad import ( context_attention_fwd, context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -54,14 +53,14 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc + infer_state.b_q_start_loc = q_start_loc context_attention_fwd( q, @@ -69,10 +68,10 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): kv[:, KV_HEADS:, :], o, infer_state.b_req_idx, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, ) @@ -127,7 +126,11 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): "batch, seqlen, q_heads, kv_heads, head_dim", [ (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] + for a in [ + 1, + 16, + 32, + ] for b in [16, 32, 512, 1024] for c in [28] for d in [4] @@ -149,18 +152,18 @@ def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, infer_state = LlamaInferStateInfo() infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX + infer_state.max_q_seq_len = N_CTX infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc + infer_state.b_q_start_loc = b_start_loc context_attention_fwd_no_prompt_cache( q, k, v, o, - infer_state.b_start_loc, + infer_state.b_q_start_loc, infer_state.b_seq_len, - infer_state.max_len_in_batch, + infer_state.max_q_seq_len, ) head_dim = HEAD_DIM diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv.py b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py similarity index 93% rename from unit_tests/models/deepseek2/test_destindex_copy_kv.py rename to unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py index 1379dc72de..ed0c6e369f 100644 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv.py +++ b/unit_tests/common/basemodel/triton_kernel/kv_copy/test_mla_destindex_copy_kv.py @@ -1,6 +1,6 @@ import torch import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv +from 
lightllm.common.basemodel.triton_kernel.kv_copy.mla_copy_kv import destindex_copy_kv from lightllm.utils.log_utils import init_logger import torch.nn.functional as F diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py similarity index 92% rename from unit_tests/models/deepseek2/test_gqa_flash_decoding.py rename to unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py index d0bc670ecb..a5ac9708d2 100644 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding.py +++ b/unit_tests/common/basemodel/triton_kernel/mla_att/decode_att/test_gqa_flash_decoding.py @@ -5,9 +5,10 @@ import torch.nn.functional as F import flashinfer from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.common.basemodel.triton_kernel.mla_att.decode_att.gqa_flash_decoding import ( + gqa_token_decode_attention_flash_decoding, +) from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -53,7 +54,7 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) + infer_state.req_manager = type("Object", (), {})() infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len @@ -67,10 +68,6 @@ def test_gqa_flash_decoding(batch, seqlen, heads, nope_head, rope_head): kv_nope, kv_rope, infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, sm_scale, o, ) diff --git a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py index 536cad90fc..0afcd5558a 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py +++ b/unit_tests/common/basemodel/triton_kernel/test_atomic_event.py @@ -18,10 +18,10 @@ def test_add_in_place(): assert input.item() == 3, "最终值应为 3" -@pytest.mark.timeout(2) -def test_wait_timeout(): - input = torch.zeros((1,), device="cuda", dtype=torch.int32) - wait_value(input, 4) +# @pytest.mark.timeout(2) +# def test_wait_timeout(): +# input = torch.zeros((1,), device="cuda", dtype=torch.int32) +# wait_value(input, 4) if __name__ == "__main__": diff --git a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py index 99971dea25..e9d0193279 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py +++ b/unit_tests/common/basemodel/triton_kernel/test_gen_sampling_params.py @@ -25,6 +25,7 @@ def test_token_id_counter(): for _ in range(100): token_id_counter(prompt_ids=test_prompt_ids, out_token_id_counter=test_token_id_counter) end_event.record() + end_event.synchronize() logger.info(f"test_token_id_count cost time: {start_event.elapsed_time(end_event)} ms") diff --git a/unit_tests/models/deepseek2/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py similarity index 96% rename from unit_tests/models/deepseek2/test_repack_kv_index.py rename to unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py index f9e5928a9e..b5184d3caa 100644 --- 
a/unit_tests/models/deepseek2/test_repack_kv_index.py +++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py @@ -1,7 +1,7 @@ import torch import pytest from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index logger = init_logger(__name__) diff --git a/unit_tests/common/fused_moe/test_deepep.py b/unit_tests/common/fused_moe/test_deepep.py index c846be0961..45778244b7 100644 --- a/unit_tests/common/fused_moe/test_deepep.py +++ b/unit_tests/common/fused_moe/test_deepep.py @@ -1,12 +1,13 @@ +import pytest + +pytest.skip(reason="need special env, install deep_ep and deep_gemm", allow_module_level=True) + import os import torch import torch.distributed as dist -import pytest import deep_ep import random import numpy as np -from deep_ep import Buffer, EventOverlap -from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from typing import Tuple @@ -25,6 +26,8 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape + from deep_gemm import ceil_div + x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index eba15b2a1c..671805a3d2 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -1,6 +1,18 @@ import torch -import time import pytest + + +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip(reason="not support fp8 test in this gpu card", allow_module_level=True) + import random from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd diff --git a/unit_tests/common/fused_moe/test_softmax_topk.py b/unit_tests/common/fused_moe/test_softmax_topk.py index 262c37a0f2..6252dfa8c3 100755 --- a/unit_tests/common/fused_moe/test_softmax_topk.py +++ b/unit_tests/common/fused_moe/test_softmax_topk.py @@ -9,7 +9,10 @@ def benchmark(M, N, K, renorm, runs): - import sgl_kernel as sgl_ops + try: + import sgl_kernel as sgl_ops + except Exception as e: + pytest.skip(f"no sgl_kernel error: {str(e)}", allow_module_level=True) gating = torch.randn(M, N, device="cuda", dtype=torch.float32) torch.cuda.synchronize() diff --git a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py index 1ddb20b632..2c0b7bf76e 100644 --- a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py +++ b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py @@ -4,6 +4,18 @@ from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +def is_fp8_native_supported(): + """检查是否为 H100/B200 等原生支持 FP8 的硬件 (SM90+)""" 
+ if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 9 + + +if not is_fp8_native_supported(): + pytest.skip("not support fp8 in this gpu card", allow_module_level=True) + + @pytest.mark.parametrize("M", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("N,K", [(2048, 2048), (4096, 5120), (8192, 4096)]) @pytest.mark.parametrize("output_dtype", [torch.bfloat16]) diff --git a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py b/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py deleted file mode 100644 index 4f9c0a3373..0000000000 --- a/unit_tests/models/deepseek2/test_destindex_copy_kv_fp8.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import pytest -from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8 -from lightllm.utils.log_utils import init_logger -import torch.nn.functional as F - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head, copy_len", - [ - (a, b, c, d, e, f) - for a in [1, 16, 32, 128, 512] - for b in [1024, 2048] - for c in [1] - for d in [512] - for e in [64] - for f in [10, 20, 100, 1024] - ], -) -def test_destindex_copy_kv_fp8(batch, seqlen, heads, nope_head, rope_head, copy_len): - B, N_CTX, H, NOPE_HEAD, ROPE_HEAD, COPY_LEN = batch, seqlen, heads, nope_head, rope_head, copy_len - dtype = torch.bfloat16 - NUM = COPY_LEN - dest_loc = torch.arange(NUM).cuda() - kv = torch.randn((len(dest_loc), H, NOPE_HEAD + ROPE_HEAD), dtype=dtype).cuda() - out = torch.zeros((B * N_CTX, H, NOPE_HEAD + ROPE_HEAD + 2), dtype=torch.uint8).cuda() - - fp8_type = torch.float8_e4m3fn - kv_nope = kv[:, :, :NOPE_HEAD] - kv_rope = kv[:, :, NOPE_HEAD:] - O_nope = out[:, :, :NOPE_HEAD].view(fp8_type) - O_rope = out[:, :, NOPE_HEAD:-2].view(fp8_type) - O_scale = out[:, :, -2:].view(dtype) - destindex_copy_kv_fp8(kv_nope, kv_rope, dest_loc, O_nope, O_rope, O_scale) - - cos1 = F.cosine_similarity(O_nope[:NUM].to(dtype) * O_scale[:NUM], kv_nope).mean() - cos2 = F.cosine_similarity(O_rope[:NUM].to(dtype) * O_scale[:NUM], kv_rope).mean() - assert cos1 > 0.98 - assert cos2 > 0.98 diff --git a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py b/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py deleted file mode 100644 index 72d9d9accc..0000000000 --- a/unit_tests/models/deepseek2/test_gqa_flash_decoding_fp8.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding -from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 -from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.common.req_manager import ReqManager - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, heads, nope_head, rope_head", - [(a, b, c, d, e) for a in [1, 16, 32, 128] for b in [16, 32, 512, 2048] for c in [16] for d in [512] for e in [64]], -) -def test_gqa_flash_decoding_fp8(batch, seqlen, heads, nope_head, rope_head): - Z, 
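The two `is_fp8_native_supported()` guards above skip whole test modules when the GPU reports a compute capability below SM90, mirroring the module-level skips used for missing optional dependencies (`deep_ep`, `sgl_kernel`). An equivalent, slightly more granular pattern is a reusable `skipif` marker; a sketch under the assumption that only compute capability matters:

```python
import pytest
import torch

def _has_native_fp8() -> bool:
    # SM90 (Hopper) and newer expose native float8_e4m3fn paths.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9

requires_fp8 = pytest.mark.skipif(
    not _has_native_fp8(), reason="needs SM90+ GPU with native FP8 support"
)

@requires_fp8
def test_fp8_cast_roundtrip():
    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    x_fp8 = x.to(torch.float8_e4m3fn)
    assert x_fp8.dtype == torch.float8_e4m3fn
```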
N_CTX, H, D_HEAD, ROPE_HEAD = batch, seqlen, heads, nope_head, rope_head - dtype = torch.bfloat16 - sm_scale = 1.0 / ((D_HEAD + ROPE_HEAD) ** 0.5) - q = torch.randn((Z, H, D_HEAD), dtype=dtype, device="cuda") - q_rope = torch.randn((Z, H, ROPE_HEAD), dtype=dtype, device="cuda") - - kv = torch.randn((Z * N_CTX, 1, D_HEAD + ROPE_HEAD), dtype=dtype, device="cuda") - kv_scale = torch.randn((Z * N_CTX, 1, 1), dtype=dtype, device="cuda") - kv_fp8 = kv.to(torch.float8_e4m3fn) - - req_to_token_indexs = torch.zeros((10, Z * N_CTX), dtype=torch.int32, device="cuda") - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") - - b_seq_len[0] = N_CTX - b_req_idx[0] = 0 - req_to_token_indexs[0][:N_CTX] = torch.tensor(np.arange(N_CTX), dtype=torch.int32).cuda() - - o = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - o1 = torch.empty((Z, H, D_HEAD), dtype=dtype, device="cuda") - - infer_state = Deepseek2InferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - - kv_nope = kv_fp8[:, :, :D_HEAD].to(dtype) * kv_scale - kv_rope = kv_fp8[:, :, D_HEAD:].to(dtype) * kv_scale - gqa_token_decode_attention_flash_decoding( - q, - q_rope, - kv_nope, - kv_rope, - infer_state, - H, - D_HEAD, - ROPE_HEAD, - D_HEAD, - sm_scale, - o, - ) - - kv_nope_fp8 = kv_fp8[:, :, :D_HEAD] - kv_rope_fp8 = kv_fp8[:, :, D_HEAD:] - gqa_token_decode_attention_flash_decoding_fp8( - q, q_rope, kv_nope_fp8, kv_rope_fp8, kv_scale, infer_state, H, D_HEAD, ROPE_HEAD, D_HEAD, sm_scale, o1 - ) - - cos_sim = F.cosine_similarity(o, o1).mean() - assert cos_sim > 0.99 diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py deleted file mode 100644 index 737bb655b1..0000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_fp8.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import time -import pytest -import triton as tl -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, 
max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - - k_cache = kv[:, :KV_HEADS, :] - v_cache = kv[:, KV_HEADS:, :] - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), - # v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), - # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=N_CTX, - # causal=True, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = q_per_head_fp8_quant(q.view(q.shape[0], kv_heads, -1), q_lens, q_starts) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim).to(torch.float8_e4m3fn), - # page_table=page_table, - 
cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=N_CTX, - causal=True, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1.item() == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py deleted file mode 100644 index 5ee2306adf..0000000000 --- a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( - context_attention_fwd, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - -if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant -else: - scaled_fp8_quant = None - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_context_attention_fwd_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 64 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") - if N_CTX > 1: - b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - q_lens = b_seq_len - b_ready_cache_len - q_start_loc = q_lens.cumsum(0) - q_lens - kv_start_loc = b_seq_len.cumsum(0) - b_seq_len - - q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_ready_cache_len = b_ready_cache_len - infer_state.b_start_loc = q_start_loc - - context_attention_fwd( - q, - kv[:, :KV_HEADS, 
:], - kv[:, KV_HEADS:, :], - o, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_len_in_batch, - req_to_token_indexs, - ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - q_starts = torch.zeros((Z + 1,)).int().cuda() - q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) - kv_starts = torch.zeros_like(q_starts) - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_indptr = q_starts.int() - kv_indptr = kv_starts.int() - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, kv_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, - qo_indptr_buf=q_indptr, - paged_kv_indptr_buf=kv_indptr, - paged_kv_indices_buf=kv_indices, - paged_kv_last_page_len_buf=kv_last_page_len_buffer, - ) - kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32) - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) - v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) - wrapper.plan( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - q_heads, - kv_heads, - head_dim, - page_size, - causal=True, - pos_encoding_mode="NONE", - logits_soft_cap=0.0, - q_data_type=q.dtype, - kv_data_type=torch.float8_e4m3fn, - ) - wrapper.run( - q, - (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), - k_scale=k_scale, - v_scale=v_scale, - out=o1, - return_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_context_attention_fwd_flashinfer_fp8(16, 1024, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad.py b/unit_tests/models/llama/test_token_attention_nopad.py deleted file mode 100644 index 1bbb291662..0000000000 --- a/unit_tests/models/llama/test_token_attention_nopad.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from 
lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX - b_start_loc = torch.arange(Z).cuda().int() * N_CTX - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # infer_state.req_manager.req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_indptr = kv_starts - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=kv_indptr, - paged_kv_indices_buffer=kv_indices, - paged_kv_last_page_len_buffer=kv_last_page_len_buffer, - ) - kv_last_page_len_buffer = torch.full((batch_size,), page_size, dtype=torch.int32) - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_len_buffer, - q_heads, - kv_heads, - head_dim, - page_size, - q_data_type=dtype, - non_blocking=True, - ) - kv = kv.unsqueeze(1) - wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) - - cos_sim1 = F.cosine_similarity(o, o1).mean() - assert cos_sim1 == 1.0 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py deleted file mode 100644 index a7f48ab899..0000000000 --- 
a/unit_tests/models/llama/test_token_attention_nopad_fa3_fp8.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): - device = kv_buffer.device - B = seq_lens.size(0) - min_fp8 = torch.finfo(torch.float8_e4m3fn).min - max_fp8 = torch.finfo(torch.float8_e4m3fn).max - _, S_max, H, D = kv_buffer.shape - seq_range = torch.arange(S_max, device=device)[None, :] - valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) - masked = kv_buffer * valid_mask - max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] - scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) - scales_exp = scales.view(B, 1, H, 1) - q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) - return q, scales - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, 
HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - q_starts = torch.arange(0, Z + 1).int().cuda() - page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) - page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) - - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - # o1 = flash_attn_with_kvcache( - # q=q, - # k_cache=k_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # v_cache=v_cache[page_table].view(-1, N_CTX, kv_heads, head_dim), - # # page_table=page_table, - # cache_seqlens=infer_state.b_seq_len, - # cu_seqlens_q=q_starts, - # cu_seqlens_k_new=kv_starts, - # max_seqlen_q=1, - # causal=False, - # window_size=(-1, -1), - # softcap=0.0, - # return_softmax_lse=False, - # ) - - q, q_scale = scaled_fp8_quant(q.view(batch_size * kv_heads, -1), use_per_token_if_dynamic=True) - k, k_scale = kv_quantize_per_head_fp8(k_cache[page_table], b_seq_len) - v, v_scale = kv_quantize_per_head_fp8(v_cache[page_table], b_seq_len) - o1 = flash_attn_with_kvcache( - q=q.view(-1, q_heads, head_dim), - k_cache=k.view(-1, N_CTX, kv_heads, head_dim), - v_cache=v.view(-1, N_CTX, kv_heads, head_dim), - # page_table=page_table, - cache_seqlens=infer_state.b_seq_len, - cu_seqlens_q=q_starts, - cu_seqlens_k_new=kv_starts, - max_seqlen_q=1, - causal=False, - window_size=(-1, -1), - softcap=0.0, - q_descale=q_scale.view(batch_size, kv_heads), - k_descale=k_scale.view(batch_size, kv_heads), - v_descale=v_scale.view(batch_size, kv_heads), - return_softmax_lse=False, - ) - - # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1 - - -if __name__ == "__main__": - test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py deleted file mode 100644 index 5c0e595b96..0000000000 --- a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_fp8.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import time -import pytest -import numpy as np -import torch.nn.functional as F -import flashinfer -from lightllm.utils.log_utils import init_logger -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd -from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant - -logger = init_logger(__name__) - -seed = 42 -torch.manual_seed(seed) - -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): - from 
lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - total_token_num = infer_state.total_token_num - batch_size = infer_state.batch_size - calcu_shape1 = (batch_size, q_h, h_dim) - - att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() - - token_att_fwd( - q.view(calcu_shape1), - k, - att_m_tensor, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - ) - - from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( - token_softmax_reducev_fwd, - ) - - token_softmax_reducev_fwd( - att_m_tensor, - v, - o, - req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_start_loc, - infer_state.b_seq_len, - ) - return o - - -@pytest.mark.parametrize( - "batch, seqlen, q_heads, kv_heads, head_dim", - [ - (a, b, c, d, e) - for a in [1, 16, 32, 128, 512] - for b in [16, 32, 512, 1024] - for c in [28] - for d in [4] - for e in [128] - ], -) -def test_token_attention_nopad_flashinfer_fp8(batch, seqlen, q_heads, kv_heads, head_dim): - Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim - dtype = torch.bfloat16 - q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - # for i in range(Z * N_CTX): - # kv[i] = torch.randn((2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") * (i % 10 + 1) - - max_input_len = Z * N_CTX - req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX) - b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) - rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") - b_seq_len += rand_num - b_start_loc = b_seq_len.cumsum(0) - b_seq_len - b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() - - o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") - - infer_state = LlamaInferStateInfo() - infer_state.batch_size = Z - infer_state.max_len_in_batch = N_CTX - infer_state.total_token_num = Z * N_CTX - infer_state.b_req_idx = b_req_idx - infer_state.b_seq_len = b_seq_len - infer_state.b_start_loc = b_start_loc - - ref_token_attention_nopad( - q, - kv[:, :KV_HEADS, :], - kv[:, KV_HEADS:, :], - o, - Q_HEADS, - HEAD_DIM, - infer_state, - req_to_token_indexs, - ) - # gqa_decode_attention_fwd( - # q, - # kv[:,:KV_HEADS,:], - # kv[:,KV_HEADS:,:], - # o, - # req_to_token_indexs, - # infer_state.b_req_idx, - # infer_state.b_seq_len, - # ) - - batch_size = Z - head_dim = HEAD_DIM - q_heads = Q_HEADS - kv_heads = KV_HEADS - page_size = 1 - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) - kv_starts = torch.zeros((Z + 1,)).int().cuda() - kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) - kv_indptr = kv_starts - kv_indices = torch.arange(Z * N_CTX).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - kv_indices[start : start + sl] = req_to_token_indexs[b][:sl] - kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) - wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=True, - paged_kv_indptr_buffer=kv_indptr, - paged_kv_indices_buffer=kv_indices, - paged_kv_last_page_len_buffer=kv_last_page_len_buffer, - ) - kv_last_page_len_buffer = torch.full((batch_size,), 
page_size, dtype=torch.int32) - k_cache = kv[:, :KV_HEADS, :].contiguous() - v_cache = kv[:, KV_HEADS:, :].contiguous() - k, k_scale = scaled_fp8_quant(k_cache.view(1, -1)) - v, v_scale = scaled_fp8_quant(v_cache.view(1, -1)) - wrapper.plan( - kv_indptr, - kv_indices, - kv_last_page_len_buffer, - q_heads, - kv_heads, - head_dim, - page_size, - q_data_type=dtype, - kv_data_type=torch.float8_e4m3fn, - non_blocking=True, - ) - wrapper.run( - q, - (k.view(-1, 1, kv_heads, head_dim), v.view(-1, 1, kv_heads, head_dim)), - k_scale=k_scale, - v_scale=v_scale, - out=o1, - return_lse=False, - ) - - cos_sim1 = F.cosine_similarity(o, o1).mean() - print(cos_sim1) - assert cos_sim1 == 1.0 - - -if __name__ == "__main__": - test_token_attention_nopad_flashinfer_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/qwen3-vl/test_deepstack_emb.py b/unit_tests/models/qwen3-vl/test_deepstack_emb.py index 2f929fe0d0..f629a16352 100644 --- a/unit_tests/models/qwen3-vl/test_deepstack_emb.py +++ b/unit_tests/models/qwen3-vl/test_deepstack_emb.py @@ -50,7 +50,7 @@ def test_deepstack_same_image_twice(): deepstack_embs=deepstack_embs, img_token_lens=img_token_lens, img_start_token_ids=img_start_token_ids, - img_start_locs=img_start_locs, + img_start_locs_in_cache=img_start_locs, ) # 7. 看看相同图片两段上的增量 diff --git a/unit_tests/server/core/objs/test_req.py b/unit_tests/server/core/objs/test_req.py index 45fa7967fb..1c946531c1 100644 --- a/unit_tests/server/core/objs/test_req.py +++ b/unit_tests/server/core/objs/test_req.py @@ -1,6 +1,22 @@ import pytest - +import easydict from lightllm.server.core.objs.req import Req, TokenHealingReq, ChunkedPrefillReq, SamplingParams +from lightllm.utils.envs_utils import set_env_start_args + + +@pytest.fixture(scope="module", autouse=True) +def setup_module_env(): + set_env_start_args( + easydict.EasyDict( + { + "mtp_step": 0, + "llm_prefill_att_backend": ["None"], + "llm_decode_att_backend": ["None"], + "cpu_cache_token_page_size": 256, + "enable_cpu_cache": False, + } + ) + ) @pytest.fixture diff --git a/unit_tests/server/core/objs/test_shm_req_manager.py b/unit_tests/server/core/objs/test_shm_req_manager.py index dea40a4859..e26f128d5b 100644 --- a/unit_tests/server/core/objs/test_shm_req_manager.py +++ b/unit_tests/server/core/objs/test_shm_req_manager.py @@ -14,6 +14,11 @@ def setup_env(): running_max_req_size=10, disable_chunked_prefill=True, token_healing_mode=False, + mtp_step=0, + llm_prefill_att_backend=["None"], + llm_decode_att_backend=["None"], + cpu_cache_token_page_size=256, + enable_cpu_cache=False, ) ) # clear the lru_cache if used From b79f44b4bee701e748232586609960efb0a9681e Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Sun, 11 Jan 2026 09:25:21 +0000 Subject: [PATCH 06/43] refactor norm and add platform --- lightllm/common/basemodel/basemodel.py | 21 ++- .../layer_weights/meta_weights/__init__.py | 2 +- .../layer_weights/meta_weights/base_weight.py | 33 +--- .../layer_weights/meta_weights/norm_weight.py | 152 +++++++++++++----- .../layer_weights/meta_weights/platform_op.py | 52 ++++++ lightllm/server/api_cli.py | 12 ++ lightllm/utils/device_utils.py | 41 +++++ 7 files changed, 233 insertions(+), 80 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 2e4a188d0a..3e39c8a1b1 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -102,14 +102,8 @@ def 
__init__(self, kvargs): self._verify_params() self._init_quant() - # 更连续的显存分配可以有更好的性能 - if self.max_total_token_num is None: - self._init_weights() - self._init_mem_manager() - else: - self._init_mem_manager() - self._init_weights() - + self._init_weights() + self._init_mem_manager() self._init_kv_move_buffer() self._check_mem_size() self._init_req_manager() @@ -117,6 +111,7 @@ def __init__(self, kvargs): self._init_some_value() self._init_custom() self._init_inferstate_cls() + self._load_hf_weights() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() self._autotune_warmup() @@ -177,6 +172,16 @@ def _init_weights(self, start_layer_index=0): ] return + def _load_hf_weights(self): + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) + return + def _init_mem_manager(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 self.mem_manager: MemoryManager = select_mem_manager_class()( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 72e0034cb8..1097774013 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -5,7 +5,7 @@ COLMMWeight, ROWBMMWeight, ) -from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight +from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 2cd8ea6ae0..b17da6682c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -13,7 +13,7 @@ def load_hf_weights(self, weights): pass @abstractmethod - def verify_load(self) -> bool: + def _create_weight(self): pass @@ -27,32 +27,5 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to def load_hf_weights(self, weights): raise NotImplementedError("load_hf_weights must implement this method") - def verify_load(self) -> bool: - raise NotImplementedError("verify_load must implement this method") - - def _get_head_tp_split_params(self, weight: torch.Tensor) -> Tuple[int, int]: - """ - Docstring for _get_head_tp_split_params, - 一个常用的tp 划分head获取head_index 范围的功能函数, 一些继承类可能会使用。 - :param self: Description - :param weight: Description - :type weight: torch.Tensor - :return: Description - :rtype: Tuple[int, int] - """ - assert weight.ndim == 2 - - all_head_num = weight.shape[0] - tp_head_num = all_head_num // self.tp_world_size_ - - if tp_head_num > 0: - start_head_index = self.tp_rank_ * tp_head_num - end_head_index = (self.tp_rank_ + 1) * tp_head_num - else: - # 当 tp_world_size 大于 all_head_num 时的特殊处理 - scale_size = self.tp_world_size_ // all_head_num - assert self.tp_world_size_ % all_head_num == 0 - start_head_index = self.tp_rank_ // scale_size - end_head_index = start_head_index + 1 - - return start_head_index, end_head_index + def _create_weight(self) -> bool: + raise NotImplementedError("create_weight must implement this method") diff 
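The refactor above splits weight handling into two phases: `_create_weight()` pre-allocates empty device tensors when the weight objects are constructed, and `load_hf_weights()` later copies checkpoint tensors into those buffers, driven by the new `_load_hf_weights()` call in `basemodel.py`. A minimal sketch of that create-then-load lifecycle, using a made-up `ToyWeight` class rather than the real LightLLM classes:

```python
import torch

class ToyWeight:
    """Illustrative only: mirrors the create-then-load split, not the real API."""

    def __init__(self, dim: int, weight_name: str, dtype=torch.float16, device="cpu"):
        self.dim = dim
        self.weight_name = weight_name
        self.dtype = dtype
        self.device = device
        self._create_weight()  # phase 1: allocate storage up front

    def _create_weight(self):
        self.weight = torch.empty(self.dim, dtype=self.dtype, device=self.device)

    def load_hf_weights(self, weights: dict):
        # phase 2: copy checkpoint data into the pre-allocated buffer in place
        if self.weight_name in weights:
            self.weight.copy_(weights[self.weight_name].to(self.dtype))


# usage: allocate first (memory layout fixed), then stream in checkpoint shards
w = ToyWeight(dim=8, weight_name="model.norm.weight")
w.load_hf_weights({"model.norm.weight": torch.ones(8)})
assert torch.all(w.weight == 1)
```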
--git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 619158fa83..de13818a5c 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,77 +1,147 @@ import torch -from typing import Optional +from typing import Optional, Dict from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward from lightllm.utils.log_utils import init_logger +from .platform_op import PlatformAwareOp logger = init_logger(__name__) -class NormWeight(BaseWeightTpl): - def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): +class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): super().__init__() - self.norm_dim = norm_dim + self.dim = dim + self.weight_name = weight_name + self.data_type_ = data_type + self._create_weight() + + def _create_weight(self): + self.weight: torch.Tensor = torch.nn.Parameter( + torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + ) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + + def _native_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + x = input.to(torch.float32) + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * self.weight).to(self.data_type_) + if out is None: + out.copy_(x) + return out + return x + + def _cuda_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + if out is None: + out = alloc_func(input.shape, dtype=input.dtype, device=input.device) + return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) + + def apply( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + + +class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__() + self.dim = dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type - self.weight = None - self.bias = None - self.is_weight_ready = False - self.is_bias_ready = False self._create_weight() def _create_weight(self): - device = f"cuda:{get_current_device_id()}" - self.weight = torch.empty(self.norm_dim, dtype=self.data_type_, device=device) - self.bias = ( - torch.empty(self.norm_dim, dtype=self.data_type_, device=device) if self.bias_name is not None else None + self.weight: torch.Tensor = torch.nn.Parameter( + torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + ) + self.bias: 
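For reference, the `_native_forward` of `RMSNormWeight` above is plain RMSNorm computed in fp32 and cast back: y = x * rsqrt(mean(x^2, dim=-1) + eps) * w. A small pure-PyTorch check of that formula (standalone sketch, names are illustrative and not the LightLLM API):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # upcast to fp32 for the variance, as the patched _native_forward does
    xf = x.float()
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    return (xf * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

# sanity check: with all-ones input, mean(x^2) == 1, so each output row equals the weight
x = torch.ones(2, 8)
w = torch.arange(8, dtype=torch.float32)
out = rmsnorm_ref(x, w, eps=0.0)
assert torch.allclose(out, w.expand(2, 8))
```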
torch.Tensor = torch.nn.Parameter( + torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) ) - def load_hf_weights(self, weights): + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) - self.is_weight_ready = True if self.bias_name in weights: self.bias.copy_(weights[self.bias_name]) - self.is_bias_ready = True - def verify_load(self): - return self.is_weight_ready and (self.bias_name is None or self.is_bias_ready) + def _native_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + x = torch.nn.functional.layer_norm( + input, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=eps + ) + if out is None: + out.copy_(x.to(self.data_type_)) + return out + return x.to(self.data_type_) - def rmsnorm_forward( + def _cuda_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - assert input.ndim in [2, 3] and self.weight.ndim == 1 - assert self.bias is None + assert input.ndim == 2 and self.weight.ndim == 1 if out is None: out = alloc_func(input.shape, dtype=input.dtype, device=input.device) - return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) + return layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps, out=out) + def apply( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) -class GEMMANormWeight(NormWeight): - def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): - super().__init__(norm_dim, weight_name, data_type, bias_name) - def load_hf_weights(self, weights): - # TODO: 这里直接 +1 会不会导致精度问题? 计算时要求 (1.0 + weight.float()) ? - if self.weight_name in weights: - self.weight.copy_((weights[self.weight_name] + 1).to(self.data_type_)) - self.is_weight_ready = True +class TpRMSNormWeight(RMSNormWeight): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = get_dp_world_size() + self.tp_rank_ = get_current_rank_in_dp() + self.dim = self._get_tp_padded_dim(dim=dim) + self.repeat_times_ = 1 + def _get_tp_padded_dim(self, dim: int): + """ + Get the padded dimension for the weight. + 1. if dim is divisible by tp_world_size_, return dim + 2. if dim is greater than tp_world_size_, return (dim + tp_world_size_ - 1) // tp_world_size_ * tp_world_size_ + 3. 
if dim is less than tp_world_size_, assert tp_world_size_ is divisible by dim, and return dim + """ + if dim % self.tp_world_size_ == 0: + return dim // self.tp_world_size_ -class TpNormWeight(NormWeight): - def __init__(self, norm_dim: int, weight_name, data_type, bias_name=None): - super().__init__(norm_dim, weight_name, data_type, bias_name) + if dim > self.tp_world_size_: + return (dim + self.tp_world_size_ - 1) // self.tp_world_size_ * self.tp_world_size_ + else: + assert ( + self.tp_world_size_ % dim == 0 + ), f"tp_world_size_ must be divisible by dim, but found: {self.tp_world_size_} % {dim}" + self.repeat_times_ = self.tp_world_size_ // dim + return dim * self.repeat_times_ // self.tp_world_size_ def load_hf_weights(self, weights): - start = self.norm_dim * self.tp_rank_ - end = self.norm_dim * (self.tp_rank_ + 1) + if self.weight_name in weights and self.weight is None: + t_weight = weights[self.weight_name] + hidden_size = t_weight.shape[0] + split_hidden_size = hidden_size // self.tp_world_size_ - if self.weight_name in weights: - self.weight.copy_(weights[self.weight_name][start:end].to(self.data_type_)) - self.is_weight_ready = True - if self.bias_name in weights: - self.bias.copy_(weights[self.bias_name][start:end].to(self.data_type_)) - self.is_bias_ready = True + start = split_hidden_size * self.tp_rank_ // self.repeat_times_ + end = min(split_hidden_size * (self.tp_rank_ + 1) // self.repeat_times_, hidden_size) + + self.weight[:, end - start].copy_(t_weight[start:end].to(self.data_type_)) + # the padding part is zero + self.weight[:, end:].zero_() diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py new file mode 100644 index 0000000000..127a543b25 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py @@ -0,0 +1,52 @@ +import torch +from abc import ABC, abstractmethod +from typing import Optional, Callable, Any +from lightllm.utils.device_utils import get_platform, Platform +from lightllm.utils.envs_utils import get_env_start_args + + +class PlatformAwareOp(ABC): + """ + platform aware op base class, + automatically route to the corresponding implementation method according to the platform. + """ + + def __init__(self): + args = get_env_start_args() + self.platform = get_platform(args.hardware_platform) + self.enable_torch_naive = args.enable_torch_naive + self._forward = self._route_forward() + + def _route_forward(self) -> Callable: + method_name_map = { + Platform.CUDA: "_cuda_forward", + Platform.ASCEND: "_ascend_forward", + Platform.CAMBRICON: "_cambricon_forward", + Platform.MUSA: "_musa_forward", + Platform.ROCM: "_rocm_forward", + Platform.CPU: "_cpu_forward", + } + + method_name = method_name_map.get(self.platform) + if method_name and hasattr(self, method_name): + method = getattr(self, method_name) + if callable(method): + return method + + if self.enable_torch_naive: + return self._native_forward + + # 如果都没有,抛出异常 + raise NotImplementedError( + f"No implementation found for platform {self.platform.name}. " + f"Please implement _{self.platform.name}_forward method, " + f"or set --enable_torch_naive to use default implementation." 
+ ) + + @abstractmethod + def _native_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("default forward must implement this method") + + @abstractmethod + def _cuda_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("cuda forward must implement this method") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d193bab41c..6f10dde6e8 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -612,4 +612,16 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) + parser.add_argument( + "--hardware_platform", + type=str, + default="cuda", + choices=["cuda", "musa"], + help="""Hardware platform: cuda | musa""", + ) + parser.add_argument( + "--enable_torch_naive", + action="store_true", + help="""Use torch naive implementation for the op.""", + ) return parser diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 09d7a680fe..a1ed6ed950 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -3,6 +3,8 @@ import torch import shutil import subprocess +from enum import Enum +from typing import Optional from functools import lru_cache from lightllm.utils.log_utils import init_logger @@ -284,3 +286,42 @@ def is_5090_gpu() -> bool: return False except: return False + + +class Platform(Enum): + """hardware platform enum""" + + CUDA = "cuda" + ASCEND = "ascend" # ascend + CAMBRICON = "cambricon" # cambricon + MUSA = "musa" # musa + ROCM = "rocm" # rocm + CPU = "cpu" # cpu + + +# 目前仅支持cuda 和 musa +def get_platform(platform_name: Optional[str] = None) -> Platform: + """ + get hardware platform. + + Args: + platform_name: platform name (cuda, ascend, cambricon, musa, rocm, cpu) + + Returns: + Platform: platform enum value + """ + assert platform_name in ["cuda", "musa"], f"Only support cuda and musa now, but got {platform_name}" + platform_name = platform_name.lower() + platform_map = { + "cuda": Platform.CUDA, + "ascend": Platform.ASCEND, + "cambricon": Platform.CAMBRICON, + "musa": Platform.MUSA, + "rocm": Platform.ROCM, + "cpu": Platform.CPU, + } + + platform = platform_map.get(platform_name) + if platform is None: + raise ValueError(f"Unknown platform name: {platform_name}") + return platform From 5740a2ef30dbac0b2f339b4328d5cc72a5ffe813 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 07:19:33 +0000 Subject: [PATCH 07/43] norm --- .../meta_weights/embedding_weight.py | 31 +++++++++------- .../layer_weights/meta_weights/norm_weight.py | 1 + .../pre_and_post_layer_weight.py | 6 ++-- .../layer_weights/transformer_layer_weight.py | 28 ++++++++++----- .../pre_and_post_layer_weight.py | 12 ++++--- .../pre_and_post_layer_weight.py | 7 ++-- .../pre_and_post_layer_weight.py | 6 ++-- .../pre_and_post_layer_weight.py | 7 ++-- .../layer_weights/transformer_layer_weight.py | 17 ++++++--- .../pre_and_post_layer_weight.py | 13 +++---- .../layer_weights/transformer_layer_weight.py | 10 ++++-- .../pre_and_post_layer_weight.py | 6 ++-- .../layer_weights/transformer_layer_weight.py | 16 ++++++--- .../pre_and_post_layer_weight.py | 11 +++--- .../layer_weights/transformer_layer_weight.py | 10 ++++-- .../pre_and_post_layer_weight.py | 6 ++-- .../pre_and_post_layer_weight.py | 7 ++-- .../pre_and_post_layer_weight.py | 6 ++-- .../pre_and_post_layer_weight.py | 7 ++-- .../layer_weights/transformer_layer_weight.py | 36 ++++++++++++++----- 20 files 
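The dispatch in `PlatformAwareOp._route_forward` above picks `_<platform>_forward` when the subclass defines it and only falls back to `_native_forward` if `--enable_torch_naive` is set; `get_platform()` currently accepts only `cuda` and `musa`. A stripped-down sketch of how an op class plugs into that routing (standalone toy classes, not the LightLLM implementations):

```python
import torch

class ToyPlatformOp:
    """Illustrative stand-in for PlatformAwareOp: route by platform name."""

    def __init__(self, platform: str = "cuda", enable_torch_naive: bool = False):
        method = getattr(self, f"_{platform}_forward", None)
        if callable(method):
            self._forward = method
        elif enable_torch_naive:
            self._forward = self._native_forward
        else:
            raise NotImplementedError(f"no implementation for platform {platform}")

class ToyScale(ToyPlatformOp):
    def _cuda_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # a platform-specific kernel would live here

    def _native_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.add(x)  # same math, pure-PyTorch fallback

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward(x)

op = ToyScale(platform="cuda")
print(op.apply(torch.ones(3)))        # tensor([2., 2., 2.])
fallback = ToyScale(platform="musa", enable_torch_naive=True)
print(fallback.apply(torch.ones(3)))  # routed to the native fallback
```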
changed, 163 insertions(+), 80 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index fc018267fa..e9b9176dd7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -4,38 +4,45 @@ from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel +from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) class EmbeddingWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type): + def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): super().__init__() + self.dim = dim + self.vocab_size = vocab_size + self.tp_world_size_ = get_dp_world_size() + self.tp_rank_ = get_current_rank_in_dp() + # 计算 split_indexes + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) self.weight_name: str = weight_name self.data_type_ = data_type self.weight: torch.Tensor = None + def _create_weight(self): + tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id + self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights or self.weight is not None: return - t_weight = weights[self.weight_name] # init some params - self.vocab_size = len(t_weight) - split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) - self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) - self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) - + loaded_vocab_size = len(t_weight) + assert ( + loaded_vocab_size == self.vocab_size + ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") - - self.weight = ( + self.weight.copy_( t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_).cuda(get_current_device_id()) ) - def verify_load(self): - return self.weight is not None - def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): if out is None: out = alloc_func( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index de13818a5c..7b966600ce 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -16,6 +16,7 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name self.dim = dim self.weight_name = weight_name self.data_type_ = data_type + assert bias_name is None, "RMSNormWeight does not have bias" self._create_weight() def _create_weight(self): diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index 83f7674531..e02af4b4e2 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ 
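The `np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64)` call in the `EmbeddingWeight` constructor above carves the vocabulary into per-rank `[start, end)` slices before any checkpoint is read. A tiny worked example of the resulting ranges (numbers chosen only for illustration, not from a real model config):

```python
import numpy as np

vocab_size, tp_world_size = 32000, 3  # deliberately non-divisible
split_indexes = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)
for rank in range(tp_world_size):
    start, end = int(split_indexes[rank]), int(split_indexes[rank + 1])
    print(f"rank {rank}: tokens [{start}, {end}) -> {end - start} rows")
# rank 0: tokens [0, 10666)     -> 10666 rows
# rank 1: tokens [10666, 21333) -> 10667 rows
# rank 2: tokens [21333, 32000) -> 10667 rows
```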
b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -1,18 +1,18 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LayerNormWeight class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.pre_norm_weight_ = NoTpNormWeight( + self.pre_norm_weight_ = LayerNormWeight( weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, bias_name="word_embeddings_layernorm.bias", ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( weight_name="ln_f.weight", data_type=self.data_type_, bias_name="ln_f.bias", diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c5a2d33527..65e00ebe7b 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -7,7 +7,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NoTpNormWeight, + RMSNormWeight, FusedMoeWeightEP, ROWBMMWeight, create_tp_moe_wegiht_obj, @@ -299,16 +299,26 @@ def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_ + hidden_size = self.network_config_["hidden_size"] + + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{self.layer_num_}.input_layernorm.weight", + data_type=self.data_type_, ) - self.ffn_norm_weight_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", self.data_type_ + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=f"model.layers.{self.layer_num_}.post_attention_layernorm.weight", + data_type=self.data_type_, ) - self.kv_a_layernorm_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", self.data_type_ + self.kv_a_layernorm_ = RMSNormWeight( + dim=self.kv_lora_rank + self.qk_rope_head_dim, + weight_name=f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", + data_type=self.data_type_, ) if self.q_lora_rank is not None: - self.q_a_layernorm_ = NoTpNormWeight( - f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", self.data_type_ + self.q_a_layernorm_ = RMSNormWeight( + dim=self.q_lora_rank, + weight_name=f"model.layers.{self.layer_num_}.self_attn.q_a_layernorm.weight", + data_type=self.data_type_, ) diff --git a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py index 1f0815c3db..719c80c27f 100644 --- a/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/deepseek_mtp/layer_weights/pre_and_post_layer_weight.py @@ -2,7 +2,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ROWMMWeight, ) @@ -11,6 +11,7 @@ class Deepseek3MTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = 
network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.eh_proj.weight", data_type=self.data_type_, @@ -18,17 +19,20 @@ def __init__(self, data_type, network_config): tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.enorm.weight", data_type=self.data_type_, bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.hnorm.weight", data_type=self.data_type_, bias_name=None, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.shared_head.norm.weight", data_type=self.data_type_, bias_name=None, diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index 3ed7004c12..7419d35e99 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): @@ -7,8 +7,9 @@ def __init__(self, data_type, network_config): super().__init__(data_type, network_config) self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) - - self.final_norm_weight_ = NoTpNormWeight( + hidden_size = network_config["hidden_size"] + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index 59caf40d6b..caef473995 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,11 +1,12 @@ import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpNormWeight, ROWMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, RMSNormWeight, ROWMMWeight class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="model.tok_embeddings.weight", data_type=self.data_type_, @@ -17,7 +18,8 @@ def __init__(self, data_type, network_config): tp_rank=0, tp_world_size=1, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 7e9ff41673..82c1f3aa2f 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ 
b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -1,11 +1,12 @@ from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class LlamaPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", data_type=self.data_type_, @@ -19,9 +20,9 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, - bias_name=None, ) return diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index a455a01f97..426230e144 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, RMSNormWeight class LlamaTransformerLayerWeight(TransformerLayerWeight): @@ -115,9 +115,16 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + bias_name=self._att_norm_bias_name, ) - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + bias_name=self._ffn_norm_bias_name, ) diff --git a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py index c9032f6fee..a65250b16d 100644 --- a/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/pre_and_post_layer_weight.py @@ -2,7 +2,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ROWMMWeight, ) @@ -10,7 +10,7 @@ class MistralMTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - + hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( weight_names="mtp.eh_proj.weight", data_type=self.data_type_, @@ -19,12 +19,13 @@ def __init__(self, data_type, network_config): tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="mtp.enorm.weight", data_type=self.data_type_, - bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + 
dim=hidden_size, weight_name="mtp.hnorm.weight", data_type=self.data_type_, bias_name=None, @@ -32,5 +33,5 @@ def __init__(self, data_type, network_config): self.wte_weight_: EmbeddingWeight = None self.lm_head_weight_: LMHeadWeight = None - self.final_norm_weight_: NoTpNormWeight = None + self.final_norm_weight_: RMSNormWeight = None return diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index 08f280b06c..b58e58799a 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, RMSNormWeight class MistralMTPTransformerLayerWeight(TransformerLayerWeight): @@ -41,6 +41,10 @@ def _init_ffn(self): ) def _init_norm(self): - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + bias_name=self._ffn_norm_bias_name, ) diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index bf9282a979..c35f5c78c9 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -1,12 +1,13 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", data_type=self.data_type_, @@ -15,7 +16,8 @@ def __init__(self, data_type, network_config): weight_name="lm_head.weight", data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="transformer.ln_f.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 90b7810adf..cbf420f509 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - NoTpNormWeight, + RMSNormWeight, ) @@ -19,6 +19,14 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - - self.q_norm_weight_ = NoTpNormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) - self.k_norm_weight_ = NoTpNormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + hidden_size = 
self.network_config_["hidden_size"] + self.q_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._q_norm_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._k_norm_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py index 8ba95c1386..e3a557d55c 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/pre_and_post_layer_weight.py @@ -4,7 +4,7 @@ EmbeddingWeight, ROWMMWeight, LMHeadWeight, - NoTpNormWeight, + RMSNormWeight, ) @@ -12,6 +12,7 @@ class Qwen3MOEMTPPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] self.eh_proj_weight_ = ROWMMWeight( weight_names="model.layers.0.proj.weight", data_type=self.data_type_, @@ -19,12 +20,14 @@ def __init__(self, data_type, network_config): tp_rank=0, tp_world_size=1, ) - self.enorm_weight_ = NoTpNormWeight( + self.enorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.norm_after_embedding.weight", data_type=self.data_type_, bias_name=None, ) - self.hnorm_weight_ = NoTpNormWeight( + self.hnorm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.layers.0.norm_before_output.weight", data_type=self.data_type_, bias_name=None, @@ -32,5 +35,5 @@ def __init__(self, data_type, network_config): # 与Qwen3MOE模型共享 self.wte_weight_: EmbeddingWeight = None self.lm_head_weight_: LMHeadWeight = None - self.final_norm_weight_: NoTpNormWeight = None + self.final_norm_weight_: RMSNormWeight = None return diff --git a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py index 095afecd91..2a11724ced 100644 --- a/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe_mtp/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight class Qwen3MOEMTPTransformerLayerWeight(Qwen3MOETransformerLayerWeight): @@ -16,6 +16,10 @@ def _init_weight(self): self._init_ffn() def _init_norm(self): - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + hidden_size = self.network_config_["hidden_size"] + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + bias_name=self._ffn_norm_bias_name, ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 52a982f495..43758731be 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -1,11 +1,12 @@ import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from 
lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, @@ -18,7 +19,8 @@ def __init__(self, data_type, network_config): weight_name="lm_head.weight", data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, weight_name="model.language_model.norm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py index 3d044eeb56..885c7ead7d 100755 --- a/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/pre_and_post_layer_weight.py @@ -1,10 +1,13 @@ -from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight, NoTpNormWeight +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import LayerNormWeight class StableLMPreAndPostLayerWeight(LlamaPreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.final_norm_weight_ = NoTpNormWeight( + hidden_size = network_config["hidden_size"] + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name="model.norm.bias", diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 34e74f1367..939c6a1464 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -1,7 +1,7 @@ from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, - NoTpNormWeight, + LayerNormWeight, NoTpPosEmbeddingWeight, LMHeadWeight, ) @@ -12,6 +12,7 @@ def __init__(self, data_type, network_config): super().__init__(data_type, network_config) def _create_weight(self): + hidden_size = self.network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="transformer.wte.weight", data_type=self.data_type_, @@ -21,7 +22,8 @@ def _create_weight(self): data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="transformer.ln_f.weight", bias_name="transformer.ln_f.bias", data_type=self.data_type_, diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index d08c27cc70..7890f82dc6 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,13 +1,13 @@ import torch import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights 
import EmbeddingWeight, LMHeadWeight, LayerNormWeight class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - + hidden_size = network_config["hidden_size"] self.wte_weight_ = EmbeddingWeight( weight_name="model.embed_tokens.weight", data_type=self.data_type_, @@ -21,7 +21,8 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) - self.final_norm_weight_ = NoTpNormWeight( + self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name="model.norm.bias", diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index dffcc16fe8..5a7a24a9a8 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -7,8 +7,9 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, COLMMWeight, - NoTpNormWeight, - TpVitPadNormWeight, + RMSNormWeight, + LayerNormWeight, + TpRMSNormWeight, ) from lightllm.utils.dist_utils import get_current_device_id @@ -119,16 +120,34 @@ def _init_ffn(self): ) def _init_norm(self): - self.att_norm_weight_ = NoTpNormWeight( - self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + norm_weight_cls = RMSNormWeight if self.norm_type == "rms_norm" else LayerNormWeight + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = norm_weight_cls( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + bias_name=self._att_norm_bias_name, ) - self.ffn_norm_weight_ = NoTpNormWeight( - self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + self.ffn_norm_weight_ = norm_weight_cls( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + bias_name=self._ffn_norm_bias_name, ) if self.qk_norm: head_num = self.network_config_["num_attention_heads"] - self.q_norm_weight_ = TpVitPadNormWeight(self._q_norm_weight_name, self.data_type_, head_num=head_num) - self.k_norm_weight_ = TpVitPadNormWeight(self._k_norm_weight_name, self.data_type_, head_num=head_num) + self.q_norm_weight_ = TpRMSNormWeight( + dim=hidden_size, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + head_num=head_num, + ) + self.k_norm_weight_ = TpRMSNormWeight( + dim=hidden_size, + weight_name=self._k_norm_weight_name, + data_type=self.data_type_, + head_num=head_num, + ) def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.attn.qkv.weight" in weights: From 0e609c8cb9c402da0f02a7e6d51439d65a2fc8da Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 12 Jan 2026 08:24:57 +0000 Subject: [PATCH 08/43] Embedding and LMHead --- .../meta_weights/embedding_weight.py | 144 ++++++++++++++---- .../layer_weights/meta_weights/norm_weight.py | 4 +- .../bloom/layer_infer/post_layer_infer.py | 4 +- .../bloom/layer_infer/pre_layer_infer.py | 6 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../pre_and_post_layer_weight.py | 8 +- .../layer_infer/transformer_layer_infer.py | 30 +--- .../layer_infer/pre_layer_infer.py | 24 +-- .../layer_infer/transformer_layer_infer.py | 26 ++-- .../pre_and_post_layer_weight.py | 4 + .../gemma_2b/layer_infer/pre_layer_infer.py | 10 +- 
.../pre_and_post_layer_weight.py | 4 + .../pre_and_post_layer_weight.py | 15 +- .../pre_and_post_layer_weight.py | 4 +- .../llama/layer_infer/post_layer_infer.py | 4 +- .../llama/layer_infer/pre_layer_infer.py | 4 +- .../layer_infer/transformer_layer_infer.py | 8 +- .../pre_and_post_layer_weight.py | 16 +- .../layer_infer/pre_layer_infer.py | 36 +---- .../pre_and_post_layer_weight.py | 7 +- .../layer_infer/transformer_layer_infer.py | 16 +- .../pre_and_post_layer_weight.py | 17 ++- .../layer_infer/transformer_layer_infer.py | 12 +- .../starcoder/layer_infer/pre_layer_infer.py | 14 +- .../pre_and_post_layer_weight.py | 8 + .../layer_infer/transformer_layer_infer.py | 12 +- .../pre_and_post_layer_weight.py | 18 +-- 27 files changed, 240 insertions(+), 219 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index e9b9176dd7..d1de857cd7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -2,7 +2,7 @@ import numpy as np from typing import Dict, Optional from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id +from .platform_op import PlatformAwareOp from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp from lightllm.utils.log_utils import init_logger @@ -10,7 +10,7 @@ logger = init_logger(__name__) -class EmbeddingWeight(BaseWeightTpl): +class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): super().__init__() self.dim = dim @@ -23,14 +23,14 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) self.weight_name: str = weight_name self.data_type_ = data_type - self.weight: torch.Tensor = None + self._create_weight() def _create_weight(self): tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name not in weights or self.weight is not None: + if self.weight_name not in weights: return t_weight = weights[self.weight_name] # init some params @@ -39,16 +39,29 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): loaded_vocab_size == self.vocab_size ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") - self.weight.copy_( - t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_).cuda(get_current_device_id()) - ) - - def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + + def _native_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + # Adjust input_ids for tp split + adjusted_ids = input_ids - self.tp_vocab_start_id + # Clamp to valid range for this partition + adjusted_ids = torch.clamp(adjusted_ids, 0, self.weight.shape[0] - 1) + # Use PyTorch native embedding + result = 
torch.nn.functional.embedding(adjusted_ids, self.weight) + if out is not None: + out.copy_(result) + return out + return result + + def _cuda_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: if out is None: out = alloc_func( (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device ) - embedding_kernel( input_ids=input_ids, weight=self.weight, @@ -56,10 +69,57 @@ def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, vob_end_id=self.tp_vocab_end_id, out=out, ) - return out - def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + def __call__( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + + +class LMHeadWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): + super().__init__() + self.dim = dim + self.vocab_size = vocab_size + self.tp_world_size_ = get_dp_world_size() + self.tp_rank_ = get_current_rank_in_dp() + # 计算 split_indexes + split_indexes = np.linspace(0, self.vocab_size, self.tp_world_size_ + 1, dtype=np.int64) + self.tp_vocab_start_id = int(split_indexes[self.tp_rank_]) + self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) + self.weight_name: str = weight_name + self.data_type_ = data_type + self._create_weight() + + def _create_weight(self): + tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id + self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name not in weights: + return + t_weight = weights[self.weight_name] + loaded_vocab_size = len(t_weight) + assert ( + loaded_vocab_size == self.vocab_size + ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" + logger.info(f"loaded weight vocab_size: {self.vocab_size}") + self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + + def _native_forward( + self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + assert input.ndim == 2 + result = torch.mm(self.weight, input) + if out is not None: + out.copy_(result) + return out + return result + + def _cuda_forward( + self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: assert input.ndim == 2 if out is None: out = alloc_func( @@ -67,49 +127,67 @@ def lm_head(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc dtype=input.dtype, device=input.device, ) - torch.mm(self.weight, input, out=out) return out - -class LMHeadWeight(EmbeddingWeight): - def __init__(self, weight_name, data_type): - super().__init__(weight_name, data_type) + def __call__(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty) -> torch.Tensor: + return self._forward(input=input, out=out, alloc_func=alloc_func) -class NoTpPosEmbeddingWeight(BaseWeightTpl): - def __init__(self, weight_name, data_type): +class NoTpPosEmbeddingWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, data_type: torch.dtype): super().__init__() + self.dim = dim + self.max_position_embeddings = 
max_position_embeddings self.weight_name: str = weight_name self.data_type_ = data_type - self.weight: torch.Tensor = None self.tp_world_size_ = 1 self.tp_rank_ = 0 + self._create_weight() + + def _create_weight(self): + self.weight: torch.Tensor = torch.empty( + self.max_position_embeddings, self.dim, dtype=self.data_type_, device=self.device_id_ + ) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - if self.weight_name not in weights or self.weight is not None: + if self.weight_name not in weights: return - t_weight = weights[self.weight_name] - self.weight = t_weight.to(self.data_type_).cuda(get_current_device_id()) - self.end_position_id: int = t_weight.shape[0] - logger.info(f"loaded weight end_position_id: {self.end_position_id}") - - def verify_load(self): - return self.weight is not None - - def embedding(self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty): + loaded_max_position_embeddings = t_weight.shape[0] + assert ( + loaded_max_position_embeddings == self.max_position_embeddings + ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" + logger.info(f"loaded weight max_position_embeddings: {self.max_position_embeddings}") + self.weight.copy_(t_weight.to(self.data_type_)) + + def _native_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty + ) -> torch.Tensor: + # Use PyTorch native embedding + result = torch.nn.functional.embedding(input_ids, self.weight) + if out is not None: + out.copy_(result) + return out + return result + + def _cuda_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: if out is None: out = alloc_func( (input_ids.shape[0], self.weight.shape[1]), dtype=self.weight.dtype, device=self.weight.device ) - embedding_kernel( input_ids=input_ids, weight=self.weight, vob_start_id=0, - vob_end_id=self.end_position_id, + vob_end_id=self.max_position_embeddings, out=out, ) - return out + + def __call__( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 7b966600ce..16a2e53dae 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -51,7 +51,7 @@ def _cuda_forward( out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) - def apply( + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) @@ -101,7 +101,7 @@ def _cuda_forward( out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps, out=out) - def apply( + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) diff --git a/lightllm/models/bloom/layer_infer/post_layer_infer.py b/lightllm/models/bloom/layer_infer/post_layer_infer.py index 
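The embedding rework above splits the vocabulary evenly across TP ranks with `np.linspace` and gives each op a torch-only `_native_forward` next to the Triton kernel path. One detail worth keeping in mind: with tensor parallelism, token ids owned by other ranks must contribute zero rows so that the later all-reduce over `input_embdings` reconstructs the full embedding; the patch's native path clamps ids into the local range, while the sketch below shows a masked variant that keeps out-of-range rows at zero (the masking is an illustration, not the code in embedding_weight.py).

```python
# Illustration of the TP vocab split used by EmbeddingWeight/LMHeadWeight and a
# torch-only embedding fallback that zeroes rows owned by other ranks.
# Standalone sketch, not the patch's implementation.
import numpy as np
import torch


def tp_vocab_range(vocab_size: int, tp_rank: int, tp_world_size: int):
    # Same even split as the weights above: rank r owns [split[r], split[r + 1]).
    split = np.linspace(0, vocab_size, tp_world_size + 1, dtype=np.int64)
    return int(split[tp_rank]), int(split[tp_rank + 1])


def native_tp_embedding(input_ids: torch.Tensor, weight: torch.Tensor, start: int, end: int):
    # `weight` holds only rows [start, end) of the full embedding table.
    in_range = (input_ids >= start) & (input_ids < end)
    local_ids = (input_ids - start).clamp_(0, weight.shape[0] - 1)
    out = torch.nn.functional.embedding(local_ids, weight)
    out[~in_range] = 0  # other ranks own these tokens; the all-reduce restores them
    return out
```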
f4fff116cd..ec1d94458f 100644 --- a/lightllm/models/bloom/layer_infer/post_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/post_layer_infer.py @@ -16,6 +16,4 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.final_norm_weight_.layernorm_forward( - input=input, eps=self.eps_, alloc_func=self.alloc_tensor - ) + return layer_weight.final_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) diff --git a/lightllm/models/bloom/layer_infer/pre_layer_infer.py b/lightllm/models/bloom/layer_infer/pre_layer_infer.py index dfe396ab52..e84069e116 100644 --- a/lightllm/models/bloom/layer_infer/pre_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/pre_layer_infer.py @@ -15,17 +15,17 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: BloomPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.pre_norm_weight_.layernorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.pre_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 808788f71a..60d584eebd 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -57,14 +57,14 @@ def _token_attention_kernel( def _att_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( + return layer_weight.att_norm_weight_( input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input: torch.Tensor, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( + return layer_weight.ffn_norm_weight_( input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) diff --git a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py index e02af4b4e2..000a069129 100644 --- a/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as 
np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LayerNormWeight @@ -7,18 +5,24 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.pre_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="word_embeddings_layernorm.weight", data_type=self.data_type_, bias_name="word_embeddings_layernorm.bias", ) self.final_norm_weight_ = LayerNormWeight( + dim=hidden_size, weight_name="ln_f.weight", data_type=self.data_type_, bias_name="ln_f.bias", ) self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="word_embeddings.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 8695f2de89..801ab6aba2 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -165,20 +165,14 @@ def _get_qkv( q, cache_kv = layer_weight.qkv_a_proj_with_mqa_.mm(input).split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1 ) - q = layer_weight.q_a_layernorm_.rmsnorm_forward( - input=q, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( @@ -208,10 +202,8 @@ def _tpsp_get_qkv( cache_kv = layer_weight.kv_a_proj_with_mqa_.mm(input).view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( q_rope, @@ -244,19 +236,13 @@ def _tpsp_get_qkv( position_sin = infer_state.position_sin q, cache_kv = qkv.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1) - q = layer_weight.q_a_layernorm_.rmsnorm_forward( - q, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + q = layer_weight.q_a_layernorm_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) q = layer_weight.q_b_proj_.mm(q) cache_kv = cache_kv.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim) q = q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim) q_nope, q_rope = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - layer_weight.kv_a_layernorm_.rmsnorm_forward( - cache_kv[:, :, : self.kv_lora_rank], - eps=self.eps_, - out=cache_kv[:, :, : 
self.kv_lora_rank], + layer_weight.kv_a_layernorm_( + cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) rotary_emb_fwd( q_rope, diff --git a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py index adb749c40e..7e12245587 100644 --- a/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/deepseek_mtp/layer_infer/pre_layer_infer.py @@ -22,16 +22,8 @@ def _mtp_context_forward( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) @@ -43,16 +35,8 @@ def _mtp_token_forward( tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) ans_logics = layer_weight.eh_proj_weight_.mm(cat_embdings) diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 1f386625bf..183c4d8d45 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -37,14 +37,10 @@ def _get_qkv( q = q.view(-1, self.tp_q_head_num_, self.head_dim_) k = cache_kv[:, 0 : self.tp_k_head_num_, :] - q = layer_weight.q_norm_weight_.rmsnorm_forward( - input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(cache_kv.dtype) + q = layer_weight.q_norm_weight_(input=q.float(), eps=self.eps_, alloc_func=self.alloc_tensor).to(cache_kv.dtype) - cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( - input=k.float(), - eps=self.eps_, - alloc_func=self.alloc_tensor, + cache_kv[:, 0 : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( + input=k.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(cache_kv.dtype) is_sliding = bool((self.layer_num_ + 1) % self.sliding_window_pattern) @@ -92,7 +88,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input1 = layer_weight.pre_feedforward_layernorm_weight_( input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) @@ -101,10 +97,8 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = 
layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( - input=ffn_out.float(), - eps=self.eps_, - alloc_func=self.alloc_tensor, + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) @@ -127,7 +121,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = layer_weight.pre_feedforward_layernorm_weight_.rmsnorm_forward( + input1 = layer_weight.pre_feedforward_layernorm_weight_( input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) @@ -136,10 +130,8 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) - ffn_out = layer_weight.post_feedforward_layernorm_weight_.rmsnorm_forward( - input=ffn_out.float(), - eps=self.eps_, - alloc_func=self.alloc_tensor, + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor ).to(torch.bfloat16) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 858937d8c1..336aa2fc30 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,12 @@ class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="language_model.model.embed_tokens.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py index 468d471d2c..e21788d762 100644 --- a/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/pre_layer_infer.py @@ -22,20 +22,14 @@ def _norm(self, input, infer_state, layer_weight: Gemma_2bPreAndPostLayerWeight) return input * self.normfactor def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding( - input_ids=input_ids, - alloc_func=self.alloc_tensor, - ) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding( - input_ids=input_ids, - alloc_func=self.alloc_tensor, - ) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) input_embdings = self._norm(input_embdings, infer_state, layer_weight) diff --git 
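From this patch on, call sites invoke the norm weights directly (`layer_weight.ffn_norm_weight_(input=..., eps=...)`) instead of going through `rmsnorm_forward` / `layernorm_forward`, so the choice between the Triton kernel and a plain-torch fallback stays inside the weight object. For reference, the math a torch-only RMSNorm fallback computes is roughly the following; this is a sketch, not the code in norm_weight.py.

```python
# Minimal torch-only RMSNorm, the kind of computation a _native_forward fallback
# performs when the Triton rmsnorm kernel is not used. Sketch only.
import torch


def native_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Accumulate in fp32 for stability, then cast back to the input dtype.
    x_f32 = x.float()
    variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
    return (x_f32 * torch.rsqrt(variance + eps)).to(x.dtype) * weight
```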
a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index 6e052caa63..fbfb2ee757 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -5,8 +5,12 @@ class Gemma_2bPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py index 7419d35e99..3bb526c79b 100644 --- a/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/pre_and_post_layer_weight.py @@ -5,9 +5,20 @@ class Internlm2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) - self.wte_weight_ = EmbeddingWeight(weight_name="model.tok_embeddings.weight", data_type=self.data_type_) - self.lm_head_weight_ = LMHeadWeight(weight_name="output.weight", data_type=self.data_type_) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.tok_embeddings.weight", + data_type=self.data_type_, + ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="output.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, weight_name="model.norm.weight", diff --git a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py index caef473995..b526192125 100644 --- a/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/internlm2_reward/layer_weights/pre_and_post_layer_weight.py @@ -1,4 +1,3 @@ -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, RMSNormWeight, ROWMMWeight @@ -7,7 +6,10 @@ class Internlm2RewardPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.tok_embeddings.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 8bc10d623c..7714164151 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -19,7 +19,7 @@ def __init__(self, network_config): return def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: - return layer_weight.final_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.final_norm_weight_(input=input, eps=self.eps_, 
alloc_func=self.alloc_tensor) def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: LlamaInferStateInfo): embed_dim_ = input_embdings.shape[1] @@ -66,7 +66,7 @@ def token_forward( input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) last_input = last_input.permute(1, 0).view(-1, token_num) - logic_batch = layer_weight.lm_head_weight_.lm_head(input=last_input, alloc_func=self.alloc_tensor) + logic_batch = layer_weight.lm_head_weight_(input=last_input, alloc_func=self.alloc_tensor) last_input = None vocab_size = layer_weight.lm_head_weight_.vocab_size if self.tp_world_size_ == 1: diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index f4f150b173..63a2fe4d14 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -15,13 +15,13 @@ def __init__(self, network_config): return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return input_embdings diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 2a9a543196..820c5efa0d 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -69,16 +69,12 @@ def _token_attention_kernel( def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.rmsnorm_forward(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + return layer_weight.att_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.rmsnorm_forward( - input=input, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) + return layer_weight.ffn_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) def _get_qkv( self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 82c1f3aa2f..d240e9ab5b 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -7,18 +7,20 @@ def __init__(self, data_type, network_config): super().__init__(data_type, network_config) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] 
self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, diff --git a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py index dbe9b61c85..96a15d18a6 100644 --- a/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py +++ b/lightllm/models/mistral_mtp/layer_infer/pre_layer_infer.py @@ -19,22 +19,10 @@ def _mtp_context_forward( input_embdings.shape[0] == tgt_embdings.shape[0] ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}" - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) - tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + tgt_embdings = layer_weight.final_norm_weight_(input=tgt_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) @@ -47,22 +35,10 @@ def _mtp_token_forward( tgt_embdings = infer_state.mtp_draft_input_hiddens assert input_embdings.shape[0] == tgt_embdings.shape[0] - layer_weight.enorm_weight_.rmsnorm_forward( - input=input_embdings, - eps=self.eps_, - out=input_embdings, - ) + layer_weight.enorm_weight_(input=input_embdings, eps=self.eps_, out=input_embdings) - tgt_embdings = layer_weight.final_norm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - alloc_func=self.alloc_tensor, - ) - layer_weight.hnorm_weight_.rmsnorm_forward( - input=tgt_embdings, - eps=self.eps_, - out=tgt_embdings, - ) + tgt_embdings = layer_weight.final_norm_weight_(input=tgt_embdings, eps=self.eps_, alloc_func=self.alloc_tensor) + layer_weight.hnorm_weight_(input=tgt_embdings, eps=self.eps_, out=tgt_embdings) cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1) diff --git a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py index c35f5c78c9..52d1a54f59 100644 --- a/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight @@ -8,11 +6,16 @@ class QwenPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + 
dim=hidden_size, + vocab_size=vocab_size, weight_name="transformer.wte.weight", data_type=self.data_type_, ) self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="lm_head.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index c85c423c29..5cd29dcdb7 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -62,13 +62,9 @@ def _get_qkv( q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - layer_weight.q_norm_weight_.rmsnorm_forward( - q.view(-1, self.head_dim_), - eps=self.eps_, - out=q.view(-1, self.head_dim_), - ) + layer_weight.q_norm_weight_(q.view(-1, self.head_dim_), eps=self.eps_, out=q.view(-1, self.head_dim_)) - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( input=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), eps=self.eps_, alloc_func=self.alloc_tensor, @@ -100,13 +96,9 @@ def _tpsp_get_qkv( q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - layer_weight.q_norm_weight_.rmsnorm_forward( - q.view(-1, self.head_dim_), - eps=self.eps_, - out=q.view(-1, self.head_dim_), - ) + layer_weight.q_norm_weight_(q.view(-1, self.head_dim_), eps=self.eps_, out=q.view(-1, self.head_dim_)) - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_.rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), eps=self.eps_, alloc_func=self.alloc_tensor, diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 43758731be..b6c7c50a07 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -1,4 +1,3 @@ -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, RMSNormWeight @@ -7,18 +6,20 @@ class Qwen3VLMOEPreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.language_model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.language_model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, weight_name="model.language_model.norm.weight", diff --git 
a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index f908dbdd3b..55848ce66a 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -53,17 +53,13 @@ def _tpsp_get_o(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, t def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.att_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.ffn_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) diff --git a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py index 6b88c066ee..b3cd083c30 100644 --- a/lightllm/models/starcoder/layer_infer/pre_layer_infer.py +++ b/lightllm/models/starcoder/layer_infer/pre_layer_infer.py @@ -14,24 +14,18 @@ def __init__(self, network_config): self.layer_norm_eps_ = network_config["layer_norm_epsilon"] def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = layer_weight.wpe_weight_.embedding( - input_ids=infer_state.position_ids, - alloc_func=self.alloc_tensor, - ) + position_embeds = layer_weight.wpe_weight_(input_ids=infer_state.position_ids, alloc_func=self.alloc_tensor) return input_embdings.add_(position_embeds) def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: StarcoderPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_.embedding(input_ids=input_ids, alloc_func=self.alloc_tensor) + input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) - position_embeds = layer_weight.wpe_weight_.embedding( - input_ids=infer_state.position_ids, - alloc_func=self.alloc_tensor, - ) + position_embeds = layer_weight.wpe_weight_(input_ids=infer_state.position_ids, alloc_func=self.alloc_tensor) return input_embdings.add_(position_embeds) diff --git a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py index 939c6a1464..a258480f65 100644 --- a/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/pre_and_post_layer_weight.py @@ -13,11 +13,17 @@ def __init__(self, data_type, network_config): def _create_weight(self): hidden_size = self.network_config["hidden_size"] + vocab_size = 
self.network_config["vocab_size"] + max_position_embeddings = self.network_config["max_position_embeddings"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="transformer.wte.weight", data_type=self.data_type_, ) self.wpe_weight_ = NoTpPosEmbeddingWeight( + dim=hidden_size, + max_position_embeddings=max_position_embeddings, weight_name="transformer.wpe.weight", data_type=self.data_type_, ) @@ -29,6 +35,8 @@ def _create_weight(self): data_type=self.data_type_, ) self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="lm_head.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 09e3299eb6..3e32682ecb 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -11,19 +11,15 @@ def __init__(self, layer_num, network_config): def _att_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layer_weight.att_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.att_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn_norm( self, input, infer_state: LlamaInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight ) -> torch.Tensor: - return layer_weight.ffn_norm_weight_.layernorm_forward( - input=input.view(-1, self.embed_dim_), - eps=self.eps_, - alloc_func=self.alloc_tensor, + return layer_weight.ffn_norm_weight_( + input=input.view(-1, self.embed_dim_), eps=self.eps_, alloc_func=self.alloc_tensor ) def _ffn( diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 7890f82dc6..c5ea7d9224 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,3 @@ -import torch -import numpy as np from lightllm.common.basemodel import PreAndPostLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, LayerNormWeight @@ -8,18 +6,20 @@ class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight): def __init__(self, data_type, network_config): super().__init__(data_type, network_config) hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, weight_name="model.embed_tokens.weight", data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - self.lm_head_weight_: LMHeadWeight = self.wte_weight_ - else: - self.lm_head_weight_ = LMHeadWeight( - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = LayerNormWeight( dim=hidden_size, From 39a738b189783af299b49cbd8d64d97268f75c72 Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 12 Jan 2026 08:47:26 +0000 Subject: [PATCH 09/43] fix LMHeadWeight --- 
.../meta_weights/embedding_weight.py | 26 ++++++++++++++++--- .../pre_and_post_layer_weight.py | 22 +++++++++++----- .../pre_and_post_layer_weight.py | 22 +++++++++++----- .../pre_and_post_layer_weight.py | 22 +++++++++++----- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index d1de857cd7..e228d5c869 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -78,7 +78,14 @@ def __call__( class LMHeadWeight(BaseWeightTpl, PlatformAwareOp): - def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): + def __init__( + self, + dim: int, + vocab_size: int, + weight_name: str, + data_type: torch.dtype, + shared_weight: Optional[EmbeddingWeight] = None, + ): super().__init__() self.dim = dim self.vocab_size = vocab_size @@ -90,13 +97,24 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) self.weight_name: str = weight_name self.data_type_ = data_type - self._create_weight() + self._shared_weight = shared_weight + if shared_weight is None: + self._create_weight() + + @property + def weight(self) -> torch.Tensor: + if self._shared_weight is not None: + return self._shared_weight.weight + return self._weight def _create_weight(self): tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id - self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + self._weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + # When using shared weight, no need to load - EmbeddingWeight already loaded it + if self._shared_weight is not None: + return if self.weight_name not in weights: return t_weight = weights[self.weight_name] @@ -105,7 +123,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): loaded_vocab_size == self.vocab_size ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") - self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self._weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) def _native_forward( self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index d240e9ab5b..2e14eca26f 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -15,12 +15,22 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", - data_type=self.data_type_, - ) + if tie_word_embeddings: + # Share weight with EmbeddingWeight to save memory + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + 
weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + shared_weight=self.wte_weight_, + ) + else: + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index b6c7c50a07..475bcee95b 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -14,12 +14,22 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.language_model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", - data_type=self.data_type_, - ) + if tie_word_embeddings: + # Share weight with EmbeddingWeight to save memory + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + shared_weight=self.wte_weight_, + ) + else: + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, weight_name="model.language_model.norm.weight", diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index c5ea7d9224..e6d5cb441d 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -14,12 +14,22 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.embed_tokens.weight" if tie_word_embeddings else "lm_head.weight", - data_type=self.data_type_, - ) + if tie_word_embeddings: + # Share weight with EmbeddingWeight to save memory + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + shared_weight=self.wte_weight_, + ) + else: + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + ) self.final_norm_weight_ = LayerNormWeight( dim=hidden_size, From 74adfc50bbcaeec8d88015411f9fb4d0b395a1a1 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 10:05:58 +0000 Subject: [PATCH 10/43] mm weight refactor --- .../layer_weights/meta_weights/__init__.py | 2 +- .../meta_weights/mm_weight/__init__.py | 2 +- .../meta_weights/mm_weight/colmm_weight.py | 10 +- .../meta_weights/mm_weight/mm_slicer.py | 129 ++++++++++-------- .../meta_weights/mm_weight/mm_weight.py | 46 ++----- .../meta_weights/mm_weight/rowmm_weight.py | 53 +++++-- .../layer_weights/transformer_layer_weight.py | 24 ++-- .../layer_weights/transformer_layer_weight.py | 30 ---- 8 files changed, 149 insertions(+), 147 deletions(-) diff --git 
a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 1097774013..cbf3998439 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -2,8 +2,8 @@ from .mm_weight import ( MMWeightTpl, ROWMMWeight, + KVROWNMMWeight, COLMMWeight, - ROWBMMWeight, ) from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index 34d989b01f..ae0c651977 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,5 +1,5 @@ from .mm_weight import ( MMWeightTpl, ) -from .rowmm_weight import ROWMMWeight, ROWBMMWeight +from .rowmm_weight import ROWMMWeight, KVROWNMMWeight from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py index bf73b9ad89..1a02e00d0f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py @@ -6,6 +6,7 @@ from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from .mm_slicer import get_col_slice_mixin @@ -21,6 +22,9 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + in_dim = self._get_tp_dim(in_dim) super().__init__( in_dim=in_dim, out_dims=out_dims, @@ -28,9 +32,9 @@ def __init__( data_type=data_type, bias_names=bias_names, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, ) self.param_slicer = get_col_slice_mixin( - self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index e2830ab611..4bc3b44a81 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -1,5 +1,5 @@ import torch -from typing import Optional +from typing import Optional, Tuple from abc import ABC, abstractmethod from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size @@ -7,9 +7,12 @@ class SliceMixinBase(ABC): """切片操作的Mixin基类""" - def __init__(self, tp_rank: int = None, tp_world_size: int = None): + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if 
tp_world_size is not None else get_dp_world_size() + # this param is used to slice the weight when tp_world_size_ is divisible by the kv_head_num + # for example, if tp_world_size_ is 8 and kv_head_num is 4, then repeat_times_ is 2 + self.repeat_times_ = repeat_times @abstractmethod def _slice_weight(self, weight: torch.Tensor): @@ -19,10 +22,16 @@ def _slice_weight(self, weight: torch.Tensor): def _slice_bias(self, bias): pass + def _get_slice_start_end(self, size: int) -> Tuple[int, int]: + tp_size = size * self.repeat_times_ // self.tp_world_size_ + start = tp_size * (self.tp_rank_ % self.repeat_times_) + end = start + tp_size + return start, end + class SliceMixinTpl(SliceMixinBase): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError("slice_weight must implement this method") @@ -40,113 +49,117 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten # 默认weight 的shape是 outxin,这也是目前最通用的约定。 # 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。 class RowSliceMixin(SliceMixinTpl): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: - assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}" - tp_size = weight.shape[0] // self.tp_world_size_ - return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + weight.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[0] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[0]) + return weight[start:end, :] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + bias.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {bias.shape[0] * self.repeat_times_} % {self.tp_world_size_}" + start, end = self._get_slice_start_end(bias.shape[0]) + return bias[start:end] # 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。 # 后续按需要,扩展per-tensor、per-channel的量化方式。 class QuantizedRowSliceMixin(RowSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( weight_scale.shape[0] % self.tp_world_size_ == 0 ), f"tp slice error {weight_scale.shape[0]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[0] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[scale_start:scale_end] + start, end = self._get_slice_start_end(weight_scale.shape[0]) + return 
weight_scale[start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( weight_zero_point.shape[0] % self.tp_world_size_ == 0 ), f"tp slice error {weight_zero_point.shape[0]} % {self.tp_world_size_}" - tp_size = weight_zero_point.shape[0] // self.tp_world_size_ - zero_point_start = tp_size * self.tp_rank_ - zero_point_end = tp_size * (self.tp_rank_ + 1) - return weight_zero_point[zero_point_start:zero_point_end] + start, end = self._get_slice_start_end(weight_zero_point.shape[0]) + return weight_zero_point[start:end] class ColSliceMixin(SliceMixinTpl): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor: - assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}" - tp_size = weight.shape[1] // self.tp_world_size_ - return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + weight.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight.shape[1]) + return weight[:, start:end] def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias / self.tp_world_size_ + return bias / self.tp_world_size_ * self.repeat_times_ class QuantizedColSliceMixin(ColSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor: assert ( - weight_scale.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}" - tp_size = weight_scale.shape[1] // self.tp_world_size_ - scale_start = tp_size * self.tp_rank_ - scale_end = tp_size * (self.tp_rank_ + 1) - return weight_scale[:, scale_start:scale_end] + weight_scale.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_scale.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_scale.shape[1]) + return weight_scale[:, start:end] def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor: assert ( - weight_zero_point.shape[1] % self.tp_world_size_ == 0 - ), f"tp slice error {weight_zero_point.shape[1]} % {self.tp_world_size_}" - tp_size = weight_zero_point.shape[1] // self.tp_world_size_ - zero_point_start = tp_size * self.tp_rank_ - zero_point_end = tp_size * (self.tp_rank_ + 1) - return weight_zero_point[:, zero_point_start:zero_point_end] + weight_zero_point.shape[1] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {weight_zero_point.shape[1] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(weight_zero_point.shape[1]) + return weight_zero_point[:, start:end] # awq 的量化权重是inxout存储格式,需要定制实现。 class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, 
repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}" - tp_size = bias.shape[0] // self.tp_world_size_ - return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] + assert ( + bias.shape[0] * self.repeat_times_ % self.tp_world_size_ == 0 + ), f"tp slice error {bias.shape[0] * self.repeat_times_ } % {self.tp_world_size_}" + start, end = self._get_slice_start_end(bias.shape[0]) + return bias[start:end] class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin): - def __init__(self, tp_rank: int = None, tp_world_size: int = None): - super().__init__(tp_rank, tp_world_size) + def __init__(self, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1): + super().__init__(tp_rank, tp_world_size, repeat_times) def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - return bias / self.tp_world_size_ + return bias / self.tp_world_size_ * self.repeat_times_ -def get_row_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: +def get_row_slice_mixin( + quant_method_name: str, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1 +) -> SliceMixinTpl: if quant_method_name.startswith("awq"): - return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size) + return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) elif quant_method_name == "none": - return RowSliceMixin(tp_rank, tp_world_size) + return RowSliceMixin(tp_rank, tp_world_size, repeat_times) else: - return QuantizedRowSliceMixin(tp_rank, tp_world_size) + return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) -def get_col_slice_mixin(quant_method_name: str, tp_rank: int = None, tp_world_size: int = None) -> SliceMixinTpl: +def get_col_slice_mixin( + quant_method_name: str, tp_rank: int = None, tp_world_size: int = None, repeat_times: int = 1 +) -> SliceMixinTpl: if quant_method_name.startswith("awq"): - return AwqQuantizedColSliceMixin(tp_rank, tp_world_size) + return AwqQuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) elif quant_method_name == "none": - return ColSliceMixin(tp_rank, tp_world_size) + return ColSliceMixin(tp_rank, tp_world_size, repeat_times) else: - return QuantizedColSliceMixin(tp_rank, tp_world_size) + return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 92236b798b..2bb7193c58 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -57,15 +57,6 @@ def __init__( self._create_weight() self.gen_weight_quant_param_names(quant_method=quant_method) - def _create_weight(self): - self.bias = None - if self.bias_names is not None: - self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) - self.mm_param: WeightPack = self.quant_method.create_weight( - in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() - ) - return - def mm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True ) -> torch.Tensor: @@ -133,6 +124,15 @@ def load_hf_weights(self, weights): def 
verify_load(self) -> bool: return True + def _create_weight(self): + self.bias = None + if self.bias_names is not None: + self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) + self.mm_param: WeightPack = self.quant_method.create_weight( + in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() + ) + return + # 执行顺序 def _load_weight( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int @@ -174,26 +174,8 @@ def _load_weight_zero_point( self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) return - -class BMMWeightTpl(MMWeightTpl): - def mm( - self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True - ) -> torch.Tensor: - raise RuntimeError("use bmm not mm") - - def bmm( - self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True - ) -> torch.Tensor: - # 目前 bmm 不支持量化运算操作 - fpweight = self.mm_param.weight - if out is None: - shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) - dtype = input_tensor.dtype - device = input_tensor.device - if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device) - else: - out = torch.empty(shape, dtype=dtype, device=device) - if self.bias is None: - return torch.bmm(input_tensor, fpweight, out=out) - return torch.addbmm(self.bias, input_tensor, fpweight, out=out) + def _get_tp_dim(self, dim: int) -> int: + assert ( + dim % self.tp_world_size_ == 0 + ), f"dim must be divisible by tp_world_size_, but found: {dim} % {self.tp_world_size_}" + return dim // self.tp_world_size_ diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index e53d643cec..e73b0cecb5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,12 +1,12 @@ import torch from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( MMWeightTpl, - BMMWeightTpl, ) from lightllm.common.quantization import Quantcfg from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod from typing import Dict, List, Optional, Union +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from .mm_slicer import get_row_slice_mixin @@ -22,6 +22,9 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + out_dims = [self._get_tp_dim(out_dim) for out_dim in out_dims] super().__init__( in_dim=in_dim, out_dims=out_dims, @@ -29,17 +32,20 @@ def __init__( bias_names=bias_names, data_type=data_type, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, ) self.param_slicer = get_row_slice_mixin( - self.quant_method.method_name, tp_rank=tp_rank, tp_world_size=tp_world_size + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) -class ROWBMMWeight(BMMWeightTpl): +class KVROWNMMWeight(MMWeightTpl): def __init__( self, + 
in_dim: int, + kv_head_num: int, + head_dim: int, weight_names: Union[str, List[str]], data_type: torch.dtype, bias_names: Optional[Union[str, List[str]]] = None, @@ -47,13 +53,42 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: + self.tp_rank = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.repeat_times = 1 + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" + ) + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim + out_dims = [kv_hidden_size, kv_hidden_size] super().__init__( + in_dim=in_dim, + out_dims=out_dims, weight_names=weight_names, data_type=data_type, bias_names=bias_names, quant_method=quant_method, - tp_rank=tp_rank, - tp_world_size=tp_world_size, + tp_rank=self.tp_rank, + tp_world_size=self.tp_world_size, ) - # bmm 不支持量化运算操作 - self.param_slicer = get_row_slice_mixin(quant_method_name="none", tp_rank=tp_rank, tp_world_size=tp_world_size) + self.param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank, + tp_world_size=self.tp_world_size, + repeat_times=self.repeat_times, + ) + + def _get_tp_padded_head_num(self, head_num: int): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + elif self.tp_world_size_ % head_num == 0: + self.repeat_times = self.tp_world_size_ // head_num + return self.repeat_times * head_num // self.tp_world_size_ + else: + raise ValueError( + f"head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by head_num, " + f"but found: {head_num} % {self.tp_world_size_}" + ) diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 426230e144..23ecbbabd9 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -2,7 +2,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, RMSNormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, RMSNormWeight, KVROWNMMWeight class LlamaTransformerLayerWeight(TransformerLayerWeight): @@ -23,16 +23,15 @@ def _init_weight(self): self._init_norm() def _parse_config(self): - self.tp_q_head_num_ = self.network_config_["num_attention_heads"] // self.tp_world_size_ - self.tp_k_head_num_ = max(self.network_config_["num_key_value_heads"] // self.tp_world_size_, 1) - self.tp_v_head_num_ = self.tp_k_head_num_ - self.tp_o_head_num_ = self.tp_q_head_num_ + self.n_head = self.network_config_["num_attention_heads"] + self.q_head_num_ = self.network_config_["num_attention_heads"] + self.k_head_num_ = self.network_config_["num_key_value_heads"] + self.v_head_num_ = self.k_head_num_ + self.o_head_num_ = self.q_head_num_ head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] self.head_dim = self.network_config_.get("head_dim", head_dim) - assert (self.tp_k_head_num_ * self.tp_world_size_) % self.network_config_["num_key_value_heads"] == 0 self.n_embed = self.network_config_["hidden_size"] 
self.n_inter = self.network_config_["intermediate_size"] - self.n_head = self.network_config_["num_attention_heads"] def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" @@ -62,9 +61,7 @@ def _init_weight_names(self): def _init_qkv(self): in_dim = self.n_embed - q_out_dim = self.tp_q_head_num_ * self.head_dim - k_out_dim = self.tp_k_head_num_ * self.head_dim - v_out_dim = self.tp_v_head_num_ * self.head_dim + q_out_dim = self.q_head_num_ * self.head_dim self.q_proj = ROWMMWeight( in_dim=in_dim, out_dims=[q_out_dim], @@ -73,9 +70,10 @@ def _init_qkv(self): bias_names=self._q_bias_name, quant_method=self.get_quant_method("q_proj"), ) - self.kv_proj = ROWMMWeight( + self.kv_proj = KVROWNMMWeight( in_dim=in_dim, - out_dims=[k_out_dim, v_out_dim], + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], @@ -83,7 +81,7 @@ def _init_qkv(self): ) def _init_o(self): - in_dim = self.tp_o_head_num_ * self.head_dim + in_dim = self.o_head_num_ * self.head_dim out_dim = self.n_embed self.o_proj = COLMMWeight( in_dim=in_dim, diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 74cf6c600f..fe6a5a2d49 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -10,33 +10,3 @@ def _init_weight_names(self): self._q_bias_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" self._k_bias_name = f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" self._v_bias_name = f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" - - def _repeat_weight(self, name, weights): - # for tp_world_size_ > num_key_value_heads - if name not in weights: - return - - tensor = weights[name] - num_kv_heads = self.network_config_["num_key_value_heads"] - repeat_size = (self.tp_k_head_num_ * self.tp_world_size_) // num_kv_heads - - if tensor.ndim == 1: - # Bias (1D tensor) - tensor = tensor.reshape(num_kv_heads, -1).unsqueeze(1).repeat(1, repeat_size, 1).reshape(-1) - else: - # Weight (2D tensor) - tensor = ( - tensor.reshape(num_kv_heads, -1, tensor.shape[-1]) - .unsqueeze(1) - .repeat(1, repeat_size, 1, 1) - .reshape(-1, tensor.shape[-1]) - ) - weights[name] = tensor - - def load_hf_weights(self, weights): - self._repeat_weight(self._k_weight_name, weights) - self._repeat_weight(self._v_weight_name, weights) - if self._k_bias_name is not None and self._v_bias_name is not None: - self._repeat_weight(self._k_bias_name, weights) - self._repeat_weight(self._v_bias_name, weights) - return super().load_hf_weights(weights) From 0cd8fca74262b1be4228ddaa3c29b914ed1d619c Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 12 Jan 2026 10:45:36 +0000 Subject: [PATCH 11/43] MOE --- .../meta_weights/embedding_weight.py | 9 +- .../fused_moe/fused_moe_weight_ep.py | 138 ++++++++++++++--- .../fused_moe/fused_moe_weight_tp.py | 142 +++++++++++------- .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 68 ++++++++- .../layer_weights/meta_weights/norm_weight.py | 6 +- 5 files changed, 274 insertions(+), 89 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index e228d5c869..e3dc0af196 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -12,7 +12,7 @@ class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.dim = dim self.vocab_size = vocab_size self.tp_world_size_ = get_dp_world_size() @@ -24,6 +24,7 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch self.weight_name: str = weight_name self.data_type_ = data_type self._create_weight() + PlatformAwareOp.__init__(self) def _create_weight(self): tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id @@ -86,7 +87,7 @@ def __init__( data_type: torch.dtype, shared_weight: Optional[EmbeddingWeight] = None, ): - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.dim = dim self.vocab_size = vocab_size self.tp_world_size_ = get_dp_world_size() @@ -100,6 +101,7 @@ def __init__( self._shared_weight = shared_weight if shared_weight is None: self._create_weight() + PlatformAwareOp.__init__(self) @property def weight(self) -> torch.Tensor: @@ -154,7 +156,7 @@ def __call__(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, allo class NoTpPosEmbeddingWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, data_type: torch.dtype): - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.dim = dim self.max_position_embeddings = max_position_embeddings self.weight_name: str = weight_name @@ -162,6 +164,7 @@ def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, dat self.tp_world_size_ = 1 self.tp_rank_ = 0 self._create_weight() + PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.empty( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 0923d5dea0..9a4feccdb4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -1,9 +1,9 @@ -import os import torch import threading from typing import Optional, Tuple, List, Dict, Any -from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight +from lightllm.utils.dist_utils import get_global_world_size, get_global_rank +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp from lightllm.common.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, @@ -29,7 +29,7 @@ logger = init_logger(__name__) -class FusedMoeWeightEP(BaseWeight): +class FusedMoeWeightEP(BaseWeightTpl, PlatformAwareOp): def __init__( self, gate_proj_name: str, @@ -44,7 +44,7 @@ def __init__( quant_cfg=None, hidden_size: Optional[int] = None, ) -> None: - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.layer_num = layer_num self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") @@ -63,7 +63,6 @@ def __init__( self.w3_weight_name = up_proj_name 
self.e_score_correction_bias_name = e_score_correction_bias_name self.n_routed_experts = n_routed_experts - self.data_type_ = data_type self.hidden_size = hidden_size global_world_size = get_global_world_size() @@ -113,6 +112,8 @@ def __init__( if self.hidden_size is not None: self._create_weight() + PlatformAwareOp.__init__(self) + def _create_weight(self): """Pre-allocate GPU memory for fused MoE weights""" if self.hidden_size is None: @@ -126,18 +127,22 @@ def _create_weight(self): # Default fallback - this will be corrected during load intermediate_size = self.hidden_size * 4 - device_id = get_current_device_id() - if not self.quantized_weight and self.quant_method is not None: # Quantized weights w1_pack = self.quant_method.create_weight( - total_expert_num * intermediate_size * 2, self.hidden_size, dtype=self.data_type_, device_id=device_id + total_expert_num * intermediate_size * 2, + self.hidden_size, + dtype=self.data_type_, + device_id=self.device_id_, ) self.w1[0] = w1_pack.weight.view(total_expert_num, intermediate_size * 2, self.hidden_size) self.w1[1] = w1_pack.weight_scale.view(total_expert_num, intermediate_size * 2, self.hidden_size) w2_pack = self.quant_method.create_weight( - total_expert_num * self.hidden_size, intermediate_size, dtype=self.data_type_, device_id=device_id + total_expert_num * self.hidden_size, + intermediate_size, + dtype=self.data_type_, + device_id=self.device_id_, ) self.w2[0] = w2_pack.weight.view(total_expert_num, self.hidden_size, intermediate_size) self.w2[1] = w2_pack.weight_scale.view(total_expert_num, self.hidden_size, intermediate_size) @@ -146,25 +151,18 @@ def _create_weight(self): self.w1[0] = torch.empty( (total_expert_num, intermediate_size * 2, self.hidden_size), dtype=self.data_type_, - device=f"cuda:{device_id}", + device=f"cuda:{self.device_id_}", ) self.w2[0] = torch.empty( (total_expert_num, self.hidden_size, intermediate_size), dtype=self.data_type_, - device=f"cuda:{device_id}", + device=f"cuda:{self.device_id_}", ) - def experts( - self, - input_tensor, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - is_prefill, + def _select_experts( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group ): + """Select experts and return topk weights and ids.""" topk_weights, topk_ids = select_experts( hidden_states=input_tensor, router_logits=router_logits, @@ -187,6 +185,74 @@ def experts( expert_counter=self.routed_expert_counter_tensor, enable_counter=self.auto_update_redundancy_expert, ) + return topk_weights, topk_ids + + def _native_forward( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + is_prefill, + ): + """PyTorch native implementation for EP MoE forward pass.""" + topk_weights, topk_ids = self._select_experts( + input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ) + + w1, w1_scale = self.w1 + w2, w2_scale = self.w2 + + # Native PyTorch implementation (less optimized but works on all platforms) + batch_size, hidden_size = input_tensor.shape + intermediate_size = w1.shape[1] // 2 + + output = torch.zeros_like(input_tensor) + + for i in range(batch_size): + expert_output = torch.zeros(hidden_size, dtype=input_tensor.dtype, device=input_tensor.device) + for j in range(top_k): + expert_idx = topk_ids[i, j].item() + weight = topk_weights[i, j] + + # Get local expert index (EP mode uses local expert indices) + 
local_expert_idx = expert_idx % self.ep_load_expert_num + + # Get expert weights + w1_expert = w1[local_expert_idx, :intermediate_size, :] # gate + w3_expert = w1[local_expert_idx, intermediate_size:, :] # up + w2_expert = w2[local_expert_idx] + + # Compute: SiLU(x @ w1.T) * (x @ w3.T) @ w2.T + x = input_tensor[i : i + 1] + gate = torch.nn.functional.silu(torch.mm(x, w1_expert.T)) + up = torch.mm(x, w3_expert.T) + hidden = gate * up + expert_out = torch.mm(hidden, w2_expert.T) + expert_output += weight * expert_out.squeeze(0) + + output[i] = expert_output + + return output + + def _cuda_forward( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + is_prefill, + ): + """CUDA optimized implementation for EP MoE forward pass.""" + topk_weights, topk_ids = self._select_experts( + input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ) w1, w1_scale = self.w1 w2, w2_scale = self.w2 @@ -207,6 +273,29 @@ def experts( previous_event=None, # for overlap ) + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + is_prefill, + ): + """Backward compatible method that routes to platform-specific implementation.""" + return self._forward( + input_tensor=input_tensor, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + is_prefill=is_prefill, + ) + def low_latency_dispatch( self, hidden_states: torch.Tensor, @@ -651,10 +740,9 @@ def _copy_expert_scales(self, target_idx, expert_id, weights): self.w2[1][target_idx].copy_(w2_scale_tensor) def _cuda(self, cpu_tensor): - device_id = get_current_device_id() if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) + return cpu_tensor.contiguous().cuda(self.device_id_) + return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_) def verify_load(self): return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index bf7b218b71..d30475444d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -1,9 +1,7 @@ -import os import torch -import threading -from typing import Tuple, List, Dict, Any, Union, Callable -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeight -from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id, get_dp_world_size +from typing import Dict, Any, Union +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.quantize_method import WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( @@ -59,7 +57,7 @@ def create_tp_moe_wegiht_obj( ) -class FusedMoeWeightTP(BaseWeight): +class FusedMoeWeightTP(BaseWeightTpl, PlatformAwareOp): def __init__( self, gate_proj_name: str, @@ -75,7 +73,7 @@ def __init__( layer_num: int, 
quant_cfg: Quantcfg = None, ) -> None: - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method.method_name != "none": @@ -92,48 +90,49 @@ def __init__( self.num_fused_shared_experts = num_fused_shared_experts self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.split_inter_size = split_inter_size - self.data_type_ = data_type self.hidden_size = network_config.get("hidden_size") - self.tp_rank_ = get_current_rank_in_dp() self.e_score_correction_bias = None self.scoring_func = network_config.get("scoring_func", "softmax") self.row_slicer = get_row_slice_mixin( - self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) self.col_slicer = get_col_slice_mixin( - self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=get_dp_world_size() + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) self._create_weight() + PlatformAwareOp.__init__(self) def _create_weight(self): total_expert_num = self.n_routed_experts intermediate_size = self.split_inter_size - device_id = get_current_device_id() # Create e_score_correction_bias if self.e_score_correction_bias is not None: self.e_score_correction_bias = torch.empty( (total_expert_num,), dtype=self.data_type_, - device=f"cuda:{device_id}", + device=f"cuda:{self.device_id_}", ) self.w13: WeightPack = self.quant_method.create_weight( out_dim=intermediate_size * 2, in_dim=self.hidden_size, dtype=self.data_type_, - device_id=device_id, + device_id=self.device_id_, num_experts=total_expert_num, ) self.w2: WeightPack = self.quant_method.create_weight( out_dim=self.hidden_size, in_dim=intermediate_size, dtype=self.data_type_, - device_id=device_id, + device_id=self.device_id_, num_experts=total_expert_num, ) - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def _select_experts( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """Select experts and return topk weights and ids.""" from lightllm.common.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -169,6 +168,53 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + return topk_weights, topk_ids + + def _native_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + topk_weights, topk_ids = self._select_experts( + input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ) + + w13, _ = self.w13.weight, self.w13.weight_scale + w2, _ = self.w2.weight, self.w2.weight_scale + + batch_size, hidden_size = input_tensor.shape + intermediate_size = w13.shape[1] // 2 + + output = torch.zeros_like(input_tensor) + + for i in range(batch_size): + expert_output = torch.zeros(hidden_size, dtype=input_tensor.dtype, device=input_tensor.device) + for j in range(top_k): + expert_idx = topk_ids[i, j].item() + weight = topk_weights[i, j] + + w1 = w13[expert_idx, :intermediate_size, :] # gate + w3 = w13[expert_idx, intermediate_size:, :] # up + w2_expert = 
w2[expert_idx] + + # Compute: SiLU(x @ w1.T) * (x @ w3.T) @ w2.T + x = input_tensor[i : i + 1] + gate = torch.nn.functional.silu(torch.mm(x, w1.T)) + up = torch.mm(x, w3.T) + hidden = gate * up + expert_out = torch.mm(hidden, w2_expert.T) + expert_output += weight * expert_out.squeeze(0) + + output[i] = expert_output + + input_tensor.copy_(output) + return + + def _cuda_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """CUDA optimized implementation of MoE forward pass.""" + topk_weights, topk_ids = self._select_experts( + input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ) w13, w13_scale = self.w13.weight, self.w13.weight_scale w2, w2_scale = self.w2.weight, self.w2.weight_scale @@ -189,11 +235,22 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t ) return + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + """Backward compatible method that routes to platform-specific implementation.""" + return self._forward( + input_tensor=input_tensor, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + ) + def _cuda(self, cpu_tensor): - device_id = get_current_device_id() if self.quantized_weight: - return cpu_tensor.cuda(device_id) - return cpu_tensor.cuda(device_id) + return cpu_tensor.cuda(self.device_id_) + return cpu_tensor.cuda(self.device_id_) def verify_load(self): return True @@ -259,42 +316,19 @@ def __init__(self, *args, **kwargs): self.workspace = marlin_make_workspace_new(self.w13.weight.device, 4) - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, + def _native_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """AWQ Marlin quantization requires CUDA, native forward not supported.""" + raise NotImplementedError("AWQ Marlin MoE requires CUDA platform, native forward not supported.") + + def _cuda_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """CUDA optimized implementation using AWQ Marlin kernels.""" + topk_weights, topk_ids = self._select_experts( + input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, 
pad_topk_weights], dim=1) w1, w1_scale, w1_zero_point = self.w13.weight, self.w13.weight_scale, self.w13.weight_zero_point w2, w2_scale, w2_zero_point = self.w2.weight, self.w2.weight_scale, self.w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 9d79ff7c25..9821b5ad66 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -1,10 +1,7 @@ -import os import torch -import threading -from typing import Optional, Tuple, List, Dict, Any +from typing import Dict, Any from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_tp import FusedMoeWeightTP -from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger @@ -121,7 +118,56 @@ def router(self, router_logits, top_k): router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_top_value, router_indices - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def _native_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """PyTorch native implementation for GPT-OSS MoE forward pass.""" + topk_weights, topk_ids = self.router(router_logits, top_k) + + w1, w1_scale = self.w1 + w2, w2_scale = self.w2 + + batch_size, hidden_size = input_tensor.shape + + output = torch.zeros_like(input_tensor) + input_bf16 = input_tensor.to(torch.bfloat16) + + for i in range(batch_size): + expert_output = torch.zeros(hidden_size, dtype=torch.bfloat16, device=input_tensor.device) + for j in range(top_k): + expert_idx = topk_ids[i, j].item() + weight = topk_weights[i, j] + + w1_expert = w1[expert_idx] + w2_expert = w2[expert_idx] + + x = input_bf16[i : i + 1] + hidden = torch.mm(x, w1_expert.T) # [1, intermediate_size * 2] + if self.w1_bias is not None: + hidden = hidden + self.w1_bias[expert_idx : expert_idx + 1] + + gate = hidden[:, 0::2] + up = hidden[:, 1::2] + + gate = torch.clamp(gate * self.alpha, -self.limit, self.limit) + gate = torch.nn.functional.sigmoid(gate) + hidden = gate * up + + expert_out = torch.mm(hidden, w2_expert.T) + if self.w2_bias is not None: + expert_out = expert_out + self.w2_bias[expert_idx : expert_idx + 1] / self.tp_world_size_ + + expert_output += weight * expert_out.squeeze(0) + + output[i] = expert_output + + input_tensor.copy_(output.to(input_tensor.dtype)) + return output + + def _cuda_forward( + self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group + ): + """CUDA optimized implementation for GPT-OSS MoE forward pass.""" topk_weights, topk_ids = self.router(router_logits, top_k) w1, w1_scale = self.w1 @@ -148,6 +194,18 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t ) return output_tensor + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + """Backward compatible method that routes to platform-specific implementation.""" + return self._forward( + input_tensor=input_tensor, + router_logits=router_logits, + top_k=top_k, + 
renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + ) + def _convert_moe_packed_tensors( self, blocks, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 16a2e53dae..6f2bf4e363 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -12,12 +12,13 @@ class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.dim = dim self.weight_name = weight_name self.data_type_ = data_type assert bias_name is None, "RMSNormWeight does not have bias" self._create_weight() + PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.nn.Parameter( @@ -59,12 +60,13 @@ def __call__( class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): - super().__init__() + BaseWeightTpl.__init__(self, data_type=data_type) self.dim = dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type self._create_weight() + PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.nn.Parameter( From cbd1726f6d2a79d66046be678b08d1a511047865 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 11:25:25 +0000 Subject: [PATCH 12/43] fix gemma norm & slicer --- .../layer_weights/meta_weights/__init__.py | 2 +- .../layer_weights/meta_weights/base_weight.py | 2 ++ .../meta_weights/mm_weight/mm_slicer.py | 2 +- .../meta_weights/mm_weight/rowmm_weight.py | 12 ++++---- .../layer_weights/meta_weights/norm_weight.py | 24 +++++++++------ lightllm/common/quantization/no_quant.py | 2 +- .../layer_weights/transformer_layer_weight.py | 29 +++++++++---------- .../pre_and_post_layer_weight.py | 1 + .../layer_weights/transformer_layer_weight.py | 19 +++++++++--- .../layer_weights/transformer_layer_weight.py | 8 +++-- .../layer_weights/transformer_layer_weight.py | 10 +++---- .../layer_weights/transformer_layer_weight.py | 5 ++-- 12 files changed, 68 insertions(+), 48 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index cbf3998439..47bf7c05f5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -5,7 +5,7 @@ KVROWNMMWeight, COLMMWeight, ) -from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight +from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index b17da6682c..58860ab30e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -6,6 +6,7 @@ class BaseWeight(ABC): def __init__(self): + 
super().__init__() pass @abstractmethod @@ -19,6 +20,7 @@ def _create_weight(self): class BaseWeightTpl(BaseWeight): def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: torch.dtype = None): + super().__init__() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.device_id_ = get_current_device_id() diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index 4bc3b44a81..ddbf98a866 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -24,7 +24,7 @@ def _slice_bias(self, bias): def _get_slice_start_end(self, size: int) -> Tuple[int, int]: tp_size = size * self.repeat_times_ // self.tp_world_size_ - start = tp_size * (self.tp_rank_ % self.repeat_times_) + start = tp_size * (self.tp_rank_ // self.repeat_times_) end = start + tp_size return start, end diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index e73b0cecb5..d7554b3757 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -53,8 +53,8 @@ def __init__( tp_rank: int = None, tp_world_size: int = None, ) -> None: - self.tp_rank = tp_rank if tp_rank is not None else get_current_rank_in_dp() - self.tp_world_size = tp_world_size if tp_world_size is not None else get_dp_world_size() + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.repeat_times = 1 assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( f"kv_head_num must be divisible by tp_world_size_ or " @@ -70,13 +70,13 @@ def __init__( data_type=data_type, bias_names=bias_names, quant_method=quant_method, - tp_rank=self.tp_rank, - tp_world_size=self.tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, ) self.param_slicer = get_row_slice_mixin( self.quant_method.method_name, - tp_rank=self.tp_rank, - tp_world_size=self.tp_world_size, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, repeat_times=self.repeat_times, ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 16a2e53dae..0232016103 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -20,9 +20,7 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name self._create_weight() def _create_weight(self): - self.weight: torch.Tensor = torch.nn.Parameter( - torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - ) + self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: @@ -67,12 +65,8 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name self._create_weight() def _create_weight(self): - self.weight: torch.Tensor = 
torch.nn.Parameter( - torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - ) - self.bias: torch.Tensor = torch.nn.Parameter( - torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - ) + self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: @@ -146,3 +140,15 @@ def load_hf_weights(self, weights): self.weight[:, end - start].copy_(t_weight[start:end].to(self.data_type_)) # the padding part is zero self.weight[:, end:].zero_() + + +class NoTpGEMMANormWeight(RMSNormWeight): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight += 1 diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index f342607c10..987601c5d6 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -23,7 +23,7 @@ def apply( dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device, is_graph_out=False) + out = g_cache_manager.alloc_tensor(shape, dtype, device=device) else: out = torch.empty(shape, dtype=dtype, device=device) if bias is None: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 65e00ebe7b..6ff081eef4 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -9,7 +9,6 @@ COLMMWeight, RMSNormWeight, FusedMoeWeightEP, - ROWBMMWeight, create_tp_moe_wegiht_obj, ) from functools import partial @@ -176,20 +175,20 @@ def _init_qkvo(self): layer_num=self.layer_num_, name="q_b_proj", ) - self.k_b_proj_ = ROWBMMWeight( - weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", - data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="k_b_proj", - ) - self.v_b_proj_ = ROWBMMWeight( - weight_names=f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", - data_type=self.data_type_, - quant_cfg=None, - layer_num=self.layer_num_, - name="v_b_proj", - ) + # self.k_b_proj_ = ROWBMMWeight( + # weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", + # data_type=self.data_type_, + # quant_cfg=None, + # layer_num=self.layer_num_, + # name="k_b_proj", + # ) + # self.v_b_proj_ = ROWBMMWeight( + # weight_names=f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", + # data_type=self.data_type_, + # quant_cfg=None, + # layer_num=self.layer_num_, + # name="v_b_proj", + # ) if self.enable_cc_method: self.cc_kv_b_proj_ = ROWMMWeight( weight_names=f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight", diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 336aa2fc30..7ae0fbcca3 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ 
b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -17,6 +17,7 @@ def __init__(self, data_type, network_config): self.lm_head_weight_ = self.wte_weight_ self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, weight_name="language_model.model.norm.weight", data_type=self.data_type_, bias_name=None, diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index e7808c412c..11e3c2f36f 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -63,13 +63,24 @@ def _init_qkv(self): def _init_norm(self): super()._init_norm() - self.k_norm_weight_ = NoTpGEMMANormWeight(self._k_norm_weight_name, self.data_type_, bias_name=None) - self.q_norm_weight_ = NoTpGEMMANormWeight(self._q_norm_weight_name, self.data_type_, bias_name=None) + + self.k_norm_weight_ = NoTpGEMMANormWeight( + dim=self.head_dim_, weight_name=self._k_norm_weight_name, data_type=self.data_type_, bias_name=None + ) + self.q_norm_weight_ = NoTpGEMMANormWeight( + dim=self.head_dim_, weight_name=self._q_norm_weight_name, data_type=self.data_type_, bias_name=None + ) self.pre_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( - self._pre_feedforward_layernorm_name, self.data_type_, bias_name=None + dim=self.n_embed, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + bias_name=None, ) self.post_feedforward_layernorm_weight_ = NoTpGEMMANormWeight( - self._post_feedforward_layernorm_name, self.data_type_, bias_name=None + dim=self.n_embed, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + bias_name=None, ) def load_hf_weights(self, weights): diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 9102ce6775..87b2fb744d 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -29,5 +29,9 @@ def _init_qkv(self): ) def _init_norm(self): - self.att_norm_weight_ = NoTpGEMMANormWeight(self._att_norm_weight_name, self.data_type_) - self.ffn_norm_weight_ = NoTpGEMMANormWeight(self._ffn_norm_weight_name, self.data_type_) + self.att_norm_weight_ = NoTpGEMMANormWeight( + dim=self.n_embed, weight_name=self._att_norm_weight_name, data_type=self.data_type_, bias_name=None + ) + self.ffn_norm_weight_ = NoTpGEMMANormWeight( + dim=self.n_embed, weight_name=self._ffn_norm_weight_name, data_type=self.data_type_, bias_name=None + ) diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 23ecbbabd9..b68903ecd1 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -93,19 +93,17 @@ def _init_o(self): ) def _init_ffn(self): - in_dim = self.n_embed - out_dim = self.n_inter // self.tp_world_size_ self.gate_up_proj = ROWMMWeight( - in_dim=in_dim, - out_dims=[out_dim, out_dim], + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( - in_dim=out_dim, - out_dims=[in_dim], + 
in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index cbf420f509..014f4f6ac2 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -19,14 +19,13 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - hidden_size = self.network_config_["hidden_size"] self.q_norm_weight_ = RMSNormWeight( - dim=hidden_size, + dim=self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_, ) self.k_norm_weight_ = RMSNormWeight( - dim=hidden_size, + dim=self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_, ) From 12ddd8ae53415882608fe6c76fffb96fe2e25572 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 11:57:48 +0000 Subject: [PATCH 13/43] fix --- .../basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 2bb7193c58..a7288b8187 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -29,7 +29,6 @@ def __init__( tp_world_size: int = None, ) -> None: super().__init__(tp_rank, tp_world_size, data_type) - self.lock = threading.Lock() self.in_dim = in_dim if isinstance(out_dims, int): From ade109f5ddbfa08345e3341bba858fcaf2016f78 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 12:09:08 +0000 Subject: [PATCH 14/43] remove data_type --- .../basemodel/layer_weights/meta_weights/norm_weight.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index b132593532..0232016103 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -12,13 +12,12 @@ class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.dim = dim self.weight_name = weight_name self.data_type_ = data_type assert bias_name is None, "RMSNormWeight does not have bias" self._create_weight() - PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) @@ -58,13 +57,12 @@ def __call__( class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.dim = dim self.weight_name = weight_name self.bias_name = bias_name self.data_type_ = data_type self._create_weight() - PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) From 45a415a6d51f09fa8835c560a7a1d64ae5a53d97 Mon Sep 17 00:00:00 2001 From: shihaobai 
<1798930569@qq.com> Date: Mon, 12 Jan 2026 12:10:20 +0000 Subject: [PATCH 15/43] remove fused_moe_weight_tp --- .../meta_weights/fused_moe_weight_tp.py | 669 ------------------ 1 file changed, 669 deletions(-) delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py deleted file mode 100644 index 9295fa96ae..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ /dev/null @@ -1,669 +0,0 @@ -import os -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any, Union -from .base_weight import BaseWeight -from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id -from lightllm.common.quantization import Quantcfg - - -def create_tp_moe_wegiht_obj( - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, -) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: - quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is not None and quant_method.method_name == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - else: - return FusedMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - - -class FusedMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None] # weight, weight_scale - self.w2 = [None, None] # weight, weight_scale - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - use_fp8_w8a8 = self.quant_method is not None - - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts - - fused_experts( - hidden_states=input_tensor, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w1_scale, - w2_scale=w2_scale, - ) - return - - def _fuse(self): - if self.quantized_weight: - self._fuse_weight_scale() - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape - up_out_dim, up_in_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_projs[0].dtype - total_expert_num = self.n_routed_experts - - w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") - - for i_experts in range(self.n_routed_experts): - w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] - w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - if not self.quantized_weight and self.quant_method is not None: - qw1, qw1_scale, qw1_zero_point = 
self.quant_method.quantize(w1) - qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2) - self.w1[0] = qw1 - self.w1[1] = qw1_scale - self.w2[0] = qw2 - self.w2[1] = qw2_scale - else: - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape - up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape - assert gate_in_dim == up_in_dim - dtype = self.experts_gate_proj_scales[0].dtype - total_expert_num = self.n_routed_experts - - w1_scale = torch.empty( - (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale) - self.w2[1] = self._cuda(w2_scale) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def load_hf_weights(self, weights): - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if self.quant_method is not None: - self._load_weight_scale(weights) - self._fuse() - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - self.split_inter_size - // 
block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - :, - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - :, - self.split_inter_size - // block_size - * self.tp_rank_ : self.split_inter_size - // block_size - * (self.tp_rank_ + 1), - ] - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.contiguous().cuda(device_id) - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None - - -class FusedAWQMARLINMoeWeightTP(BaseWeight): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.weight_zero_point_suffix = self.quant_method.weight_zero_point_suffix - self.quant_method.is_moe = True - hf_quantization_config = network_config.get("quantization_config", None) - self.num_bits = hf_quantization_config.get("bits", 4) - self.group_size = hf_quantization_config.get("group_size", 128) - self.pack_factor = 32 // self.num_bits - self.has_processed_weight = False - assert self.quant_method.method_name == "awq_marlin" - - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
- self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.tp_rank_ = get_current_rank_in_dp() - self.experts_up_projs = [None] * self.n_routed_experts - self.experts_gate_projs = [None] * self.n_routed_experts - self.experts_up_proj_scales = [None] * self.n_routed_experts - self.experts_up_proj_zero_points = [None] * self.n_routed_experts - self.experts_gate_proj_scales = [None] * self.n_routed_experts - self.experts_gate_proj_zero_points = [None] * self.n_routed_experts - self.e_score_correction_bias = None - self.w2_list = [None] * self.n_routed_experts - self.w2_scale_list = [None] * self.n_routed_experts - self.w2_zero_point_list = [None] * self.n_routed_experts - self.scoring_func = network_config.get("scoring_func", "softmax") - self.w1 = [None, None, None] # weight, weight_scale, zero_point - self.w2 = [None, None, None] # weight, weight_scale, zero_point - self.lock = threading.Lock() - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - (topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - - w1, w1_scale, w1_zero_point = self.w1 - w2, w2_scale, w2_zero_point = self.w2 - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - - fused_marlin_moe( - input_tensor, - w1, - w2, - None, - None, - w1_scale, - w2_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=self.quant_method.vllm_quant_type.id, - apply_router_weight_on_input=False, - global_num_experts=-1, - expert_map=None, - w1_zeros=w1_zero_point, - w2_zeros=w2_zero_point, - workspace=self.workspace, - inplace=True, - ) - - return - - def _fuse(self): - self._fuse_weight() - self._fuse_weight_scale() - self._fuse_weight_zero_point() - - def _fuse_weight(self): - with self.lock: - if ( - hasattr(self, "experts_up_projs") - and None not in self.experts_up_projs - and None not in self.experts_gate_projs - and None not in self.w2_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_projs[0].shape - up_in_dim, up_out_dim = self.experts_up_projs[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - - w1 = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - - for i_experts in range(self.n_routed_experts): - w1[i_experts, :, 0:gate_out_dim] 
= self.experts_gate_projs[i_experts] - w1[i_experts, :, gate_out_dim:] = self.experts_up_projs[i_experts] - - inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] - w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) - self.w1[0] = self._cuda(w1) - self.w2[0] = self._cuda(w2) - delattr(self, "w2_list") - delattr(self, "experts_up_projs") - delattr(self, "experts_gate_projs") - - def _fuse_weight_scale(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_scales") - and None not in self.experts_up_proj_scales - and None not in self.experts_gate_proj_scales - and None not in self.w2_scale_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_scales[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_scales[0].shape - dtype = self.experts_gate_proj_scales[0].dtype - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_scale = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=dtype, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_scale[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_scales[i_experts] - w1_scale[i_experts, :, gate_out_dim:] = self.experts_up_proj_scales[i_experts] - inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] - w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( - len(self.w2_scale_list), inter_shape, hidden_size - ) - self.w1[1] = self._cuda(w1_scale).to(self.data_type_) - self.w2[1] = self._cuda(w2_scale).to(self.data_type_) - delattr(self, "w2_scale_list") - delattr(self, "experts_up_proj_scales") - delattr(self, "experts_gate_proj_scales") - - def _fuse_weight_zero_point(self): - with self.lock: - if ( - hasattr(self, "experts_up_proj_zero_points") - and None not in self.experts_up_proj_zero_points - and None not in self.experts_gate_proj_zero_points - and None not in self.w2_zero_point_list - ): - gate_in_dim, gate_out_dim = self.experts_gate_proj_zero_points[0].shape - up_in_dim, up_out_dim = self.experts_up_proj_zero_points[0].shape - assert gate_in_dim == up_in_dim - total_expert_num = self.n_routed_experts - w1_zero_point = torch.empty( - (total_expert_num, gate_in_dim, gate_out_dim + up_out_dim), dtype=torch.int32, device="cpu" - ) - for i_experts in range(self.n_routed_experts): - w1_zero_point[i_experts, :, 0:gate_out_dim] = self.experts_gate_proj_zero_points[i_experts] - w1_zero_point[i_experts, :, gate_out_dim:] = self.experts_up_proj_zero_points[i_experts] - inter_shape, hidden_size = self.w2_zero_point_list[0].shape[0], self.w2_zero_point_list[0].shape[1] - w2_zero_point = torch._utils._flatten_dense_tensors(self.w2_zero_point_list).view( - len(self.w2_zero_point_list), inter_shape, hidden_size - ) - self.w1[2] = self._cuda(w1_zero_point) - self.w2[2] = self._cuda(w2_zero_point) - delattr(self, "w2_zero_point_list") - delattr(self, "experts_up_proj_zero_points") - delattr(self, "experts_gate_proj_zero_points") - - def load_hf_weights(self, weights): - self._load_weight(weights) - self._load_weight_scale(weights) - self._load_weight_zero_point(weights) - self._fuse() - self._process_weight_after_loading() - - def _load_weight(self, weights: Dict[str, torch.Tensor]) -> None: - # awq quantization weight shape: in x out - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name]) - for i_experts in range(self.n_routed_experts): - 
w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.qweight" - w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.qweight" - w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.qweight" - - if w1_weight in weights: - self.experts_gate_projs[i_experts] = weights[w1_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - if w3_weight in weights: - self.experts_up_projs[i_experts] = weights[w3_weight][ - :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1) - ] - - if w2_weight in weights: - self.w2_list[i_experts] = weights[w2_weight][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), : - ] - - def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}" - w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}" - w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}" - split_inter_size = self.split_inter_size * self.pack_factor - if w1_scale in weights: - self.experts_gate_proj_scales[i_experts] = weights[w1_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - if w3_scale in weights: - self.experts_up_proj_scales[i_experts] = weights[w3_scale][ - :, - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - ] - - if w2_scale in weights: - self.w2_scale_list[i_experts] = weights[w2_scale][ - split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _load_weight_zero_point(self, weights: Dict[str, torch.Tensor]) -> None: - for i_experts in range(self.n_routed_experts): - w1_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_zero_point_suffix}" - w2_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_zero_point_suffix}" - w3_zero_point = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_zero_point_suffix}" - if w1_zero_point in weights: - self.experts_gate_proj_zero_points[i_experts] = weights[w1_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w3_zero_point in weights: - self.experts_up_proj_zero_points[i_experts] = weights[w3_zero_point][ - :, - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - ] - if w2_zero_point in weights: - self.w2_zero_point_list[i_experts] = weights[w2_zero_point][ - self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), - :, - ] - - def _process_weight_after_loading(self): - with self.lock: - if None in self.w1 or None in self.w2 or self.has_processed_weight: - return - self.has_processed_weight = True - from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - - assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" - - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_moe_permute_scales, - moe_awq_to_marlin_zero_points, - marlin_make_workspace_new, - ) - - num_experts = self.n_routed_experts - device = self.w1[0].device - - self.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w2_g_idx_sort_indices = 
torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - self.w1[0] = vllm_ops.awq_marlin_moe_repack( - self.w1[0], - self.w13_g_idx_sort_indices, - size_k=self.w1[0].shape[1], - size_n=self.w1[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[0] = vllm_ops.awq_marlin_moe_repack( - self.w2[0], - self.w2_g_idx_sort_indices, - size_k=self.w2[0].shape[1], - size_n=self.w2[0].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - # Why does this take the intermediate size for size_k? - self.w1[1] = marlin_moe_permute_scales( - s=self.w1[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w1[1].shape[2], - group_size=self.group_size, - ) - - self.w2[1] = marlin_moe_permute_scales( - s=self.w2[1], - size_k=self.split_inter_size * self.pack_factor, - size_n=self.w2[1].shape[2], - group_size=self.group_size, - ) - - self.w1[2] = moe_awq_to_marlin_zero_points( - self.w1[2], - size_k=self.w1[2].shape[1], - size_n=self.w1[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.w2[2] = moe_awq_to_marlin_zero_points( - self.w2[2], - size_k=self.w2[2].shape[1], - size_n=self.w2[2].shape[2] * self.pack_factor, - num_bits=self.num_bits, - ) - - self.workspace = marlin_make_workspace_new(device, 4) - - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - if self.quantized_weight: - return cpu_tensor.cuda(device_id) - return cpu_tensor.cuda(device_id) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None From 6e227b19e22f16a7de9cf532f59683dbb0a193db Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 13:11:47 +0000 Subject: [PATCH 16/43] qk norm --- .../layer_weights/meta_weights/__init__.py | 2 +- .../fused_moe/fused_moe_weight_tp.py | 8 ---- .../layer_weights/meta_weights/norm_weight.py | 45 +++++++++++++++++-- .../basemodel}/triton_kernel/qk_norm.py | 0 .../layer_infer/transformer_layer_infer.py | 9 +--- .../layer_weights/transformer_layer_weight.py | 6 +-- .../layer_infer/transformer_layer_infer.py | 15 +++---- .../layer_infer/transformer_layer_infer.py | 7 +-- .../layer_infer/transformer_layer_infer.py | 7 +-- 9 files changed, 58 insertions(+), 41 deletions(-) rename lightllm/{models/qwen3 => common/basemodel}/triton_kernel/qk_norm.py (100%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index 47bf7c05f5..fef70acf50 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -5,7 +5,7 @@ KVROWNMMWeight, COLMMWeight, ) -from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight +from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index d30475444d..51822f7901 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ 
-247,14 +247,6 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t num_expert_group=num_expert_group, ) - def _cuda(self, cpu_tensor): - if self.quantized_weight: - return cpu_tensor.cuda(self.device_id_) - return cpu_tensor.cuda(self.device_id_) - - def verify_load(self): - return True - def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 0232016103..73b937b776 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -4,11 +4,9 @@ from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward -from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.triton_kernel.qk_norm import qk_rmsnorm_forward from .platform_op import PlatformAwareOp -logger = init_logger(__name__) - class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): @@ -152,3 +150,44 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) self.weight += 1 + + +class QKRMSNORMWeight(RMSNormWeight): + def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__(dim=dim, weight_name=weight_name, data_type=data_type, bias_name=bias_name) + self.tp_world_size_ = 1 + self.tp_rank_ = 0 + + def _native_forward( + self, + input: torch.Tensor, + eps: float, + ) -> None: + assert input.ndim == 2 and self.weight.ndim == 1 + assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" + head_dim = self.weight.shape[0] + x = input.to(torch.float32) + x = x.view(-1, head_dim) + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * self.weight).to(self.data_type_) + x = x.view(-1, input.shape[-1]) + input.copy_(x) + return + + def _cuda_forward( + self, + input: torch.Tensor, + eps: float, + ) -> None: + assert input.ndim == 2 and self.weight.ndim == 1 + qk_rmsnorm_forward(x=input, weight=self.weight, eps=eps) + return + + def __call__( + self, + input: torch.Tensor, + eps: float, + ) -> None: + return self._forward(input=input, eps=eps) diff --git a/lightllm/models/qwen3/triton_kernel/qk_norm.py b/lightllm/common/basemodel/triton_kernel/qk_norm.py similarity index 100% rename from lightllm/models/qwen3/triton_kernel/qk_norm.py rename to lightllm/common/basemodel/triton_kernel/qk_norm.py diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 5f0c91287d..82331f8fb8 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -4,9 +4,6 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from 
lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward -from functools import partial from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -27,14 +24,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 014f4f6ac2..7d2163f283 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight from lightllm.common.basemodel.layer_weights.meta_weights import ( - RMSNormWeight, + QKRMSNORMWeight, ) @@ -19,12 +19,12 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = RMSNormWeight( + self.q_norm_weight_ = QKRMSNORMWeight( dim=self.head_dim, weight_name=self._q_norm_name, data_type=self.data_type_, ) - self.k_norm_weight_ = RMSNormWeight( + self.k_norm_weight_ = QKRMSNORMWeight( dim=self.head_dim, weight_name=self._k_norm_name, data_type=self.data_type_, diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 5cd29dcdb7..d273d51ad5 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -60,16 +60,13 @@ def _get_qkv( ) -> Tuple[torch.Tensor, torch.Tensor]: input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - layer_weight.q_norm_weight_(q.view(-1, self.head_dim_), eps=self.eps_, out=q.view(-1, self.head_dim_)) - - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( - input=cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), + cache_kv = layer_weight.kv_proj.mm(input) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, - alloc_func=self.alloc_tensor, - ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) - + ) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, : self.tp_k_head_num_, :], diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index d1c51365a1..d34babaabe 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -9,7 +9,6 @@ from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features from 
lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -30,14 +29,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 328cc0a625..4ccc6da372 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -5,7 +5,6 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward from lightllm.distributed import all_reduce from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features @@ -26,14 +25,12 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - qk_rmsnorm_forward( + layer_weight.q_norm_weight_( q, - weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, ) - qk_rmsnorm_forward( + layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) From eadcccace784a40d081a861f24b68a32d9b3cb77 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 12 Jan 2026 13:13:53 +0000 Subject: [PATCH 17/43] remove PlatformAwareOp.__init__() --- .../layer_weights/meta_weights/embedding_weight.py | 9 +++------ .../meta_weights/fused_moe/fused_moe_weight_ep.py | 4 +--- .../meta_weights/fused_moe/fused_moe_weight_tp.py | 3 +-- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index e3dc0af196..e228d5c869 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -12,7 +12,7 @@ class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.dim = dim self.vocab_size = vocab_size self.tp_world_size_ = get_dp_world_size() @@ -24,7 +24,6 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch self.weight_name: str = weight_name self.data_type_ = data_type self._create_weight() - PlatformAwareOp.__init__(self) def _create_weight(self): tp_vocab_size = self.tp_vocab_end_id - 
self.tp_vocab_start_id @@ -87,7 +86,7 @@ def __init__( data_type: torch.dtype, shared_weight: Optional[EmbeddingWeight] = None, ): - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.dim = dim self.vocab_size = vocab_size self.tp_world_size_ = get_dp_world_size() @@ -101,7 +100,6 @@ def __init__( self._shared_weight = shared_weight if shared_weight is None: self._create_weight() - PlatformAwareOp.__init__(self) @property def weight(self) -> torch.Tensor: @@ -156,7 +154,7 @@ def __call__(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, allo class NoTpPosEmbeddingWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, data_type: torch.dtype): - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.weight_name: str = weight_name @@ -164,7 +162,6 @@ def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, dat self.tp_world_size_ = 1 self.tp_rank_ = 0 self._create_weight() - PlatformAwareOp.__init__(self) def _create_weight(self): self.weight: torch.Tensor = torch.empty( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 9a4feccdb4..a84d198937 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -44,7 +44,7 @@ def __init__( quant_cfg=None, hidden_size: Optional[int] = None, ) -> None: - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.layer_num = layer_num self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") @@ -112,8 +112,6 @@ def __init__( if self.hidden_size is not None: self._create_weight() - PlatformAwareOp.__init__(self) - def _create_weight(self): """Pre-allocate GPU memory for fused MoE weights""" if self.hidden_size is None: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index 51822f7901..c6b3dc9656 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -73,7 +73,7 @@ def __init__( layer_num: int, quant_cfg: Quantcfg = None, ) -> None: - BaseWeightTpl.__init__(self, data_type=data_type) + super().__init__() self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method.method_name != "none": @@ -100,7 +100,6 @@ def __init__( self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) self._create_weight() - PlatformAwareOp.__init__(self) def _create_weight(self): total_expert_num = self.n_routed_experts From 7ea5e7785098282ae9266cdc766e0f32579d7284 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 13 Jan 2026 03:50:55 +0000 Subject: [PATCH 18/43] fix model call --- .../layer_weights/transformer_layer_weight.py | 12 ++-- .../layer_weights/transformer_layer_weight.py | 72 ++++++++++--------- .../layer_weights/transformer_layer_weight.py | 25 +++---- .../pre_and_post_layer_weight.py | 1 + .../layer_weights/transformer_layer_weight.py | 15 ++-- 
.../layer_weights/transformer_layer_weight.py | 5 +- .../layer_weights/transformer_layer_weight.py | 16 +++-- .../layer_weights/transformer_layer_weight.py | 6 +- .../layer_weights/transformer_layer_weight.py | 12 ++-- .../layer_weights/transformer_layer_weight.py | 12 ++-- .../layer_weights/transformer_layer_weight.py | 32 +++++---- 11 files changed, 114 insertions(+), 94 deletions(-) diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 599893655d..568a9c0381 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -108,18 +108,18 @@ def load_hf_weights(self, weights): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_up_weight_name, data_type=self.data_type_, bias_names=self._gate_up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 6ff081eef4..05897203ae 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -43,6 +43,13 @@ def _parse_config(self): moe_mode = os.getenv("MOE_MODE", "TP") assert moe_mode == "TP" self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + self.moe_inter = self.network_config_.get("moe_intermediate_size", self.n_inter) + self.q_out_dim = self.num_attention_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim) + self.kv_a_out_dim = self.kv_lora_rank + self.qk_rope_head_dim + self.kv_b_out_dim = self.num_attention_heads * (self.qk_nope_head_dim + self.v_head_dim) + self.o_in_dim = self.num_attention_heads * self.v_head_dim def _init_weight_names(self): if self.q_lora_rank is None: @@ -140,40 +147,40 @@ def load_hf_weights(self, weights): def _init_qkvo(self): if self.q_lora_rank is None: self.q_weight_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.q_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.q_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_weight", + quant_method=self.get_quant_method("q_weight"), ) self.kv_a_proj_with_mqa_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.kv_a_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_a_proj_with_mqa", + quant_method=self.get_quant_method("kv_a_proj_with_mqa"), tp_rank=0, tp_world_size=1, ) else: self.qkv_a_proj_with_mqa_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.q_lora_rank, self.kv_a_out_dim], weight_names=[ f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", 
f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight", ], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="qkv_a_proj_with_mqa", + quant_method=self.get_quant_method("qkv_a_proj_with_mqa"), tp_rank=0, tp_world_size=1, ) self.q_b_proj_ = ROWMMWeight( + in_dim=self.q_lora_rank, + out_dims=[self.q_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_b_proj", + quant_method=self.get_quant_method("q_b_proj"), ) # self.k_b_proj_ = ROWBMMWeight( # weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", @@ -191,65 +198,66 @@ def _init_qkvo(self): # ) if self.enable_cc_method: self.cc_kv_b_proj_ = ROWMMWeight( + in_dim=self.kv_lora_rank, + out_dims=[self.kv_b_out_dim], weight_names=f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="cc_kv_b_proj", + quant_method=self.get_quant_method("cc_kv_b_proj"), ) self.o_weight_ = COLMMWeight( + in_dim=self.o_in_dim, + out_dims=[self.n_embed], weight_names=f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_weight", + quant_method=self.get_quant_method("o_weight"), ) def _load_mlp(self, mlp_prefix): moe_mode = os.getenv("MOE_MODE", "TP") if self.is_moe and moe_mode == "EP": self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.moe_inter, self.moe_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), tp_rank=0, tp_world_size=1, ) self.down_proj = COLMMWeight( + in_dim=self.moe_inter, + out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), tp_rank=0, tp_world_size=1, ) else: self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=self.get_quant_method("moe_gate"), tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index 11e3c2f36f..a4340a17a7 100644 --- a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py @@ -25,39 +25,40 @@ def _init_weight_names(self): def _init_ffn(self): self.gate_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_weight_name, data_type=self.data_type_, bias_names=self._gate_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_proj", + quant_method=self.get_quant_method("gate_proj"), ) self.up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._up_weight_name, data_type=self.data_type_, bias_names=self._up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="up_proj", + quant_method=self.get_quant_method("up_proj"), ) super()._init_ffn() def _init_qkv(self): + kv_out_dim = self.k_head_num_ * self.head_dim self.k_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[kv_out_dim], weight_names=self._k_weight_name, data_type=self.data_type_, bias_names=self._k_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="k_proj", + quant_method=self.get_quant_method("k_proj"), ) self.v_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[kv_out_dim], weight_names=self._v_weight_name, data_type=self.data_type_, bias_names=self._v_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="v_proj", + quant_method=self.get_quant_method("v_proj"), ) super()._init_qkv() diff --git a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py index fbfb2ee757..23ae50f096 100644 --- a/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/pre_and_post_layer_weight.py @@ -17,6 +17,7 @@ def __init__(self, data_type, network_config): self.lm_head_weight_ = self.wte_weight_ self.final_norm_weight_ = NoTpGEMMANormWeight( + dim=hidden_size, weight_name="model.norm.weight", data_type=self.data_type_, bias_name=None, diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 87b2fb744d..19951b9900 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -11,21 +11,24 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): return def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + kv_out_dim = self.k_head_num_ * self.head_dim self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], weight_names=self._q_weight_name, data_type=self.data_type_, bias_names=self._q_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="q_proj", + quant_method=self.get_quant_method("q_proj"), ) self.kv_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim, kv_out_dim], weight_names=[self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="kv_proj", + quant_method=self.get_quant_method("kv_proj"), ) def _init_norm(self): diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 5fb85aa1cd..0e7f4c8732 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -31,11 +31,12 @@ def _init_moe(self): assert moe_mode in ["TP"], "For now, GPT-OSS type model only support MOE TP mode." self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[n_routed_experts], weight_names=self._router_weight_name, data_type=self.data_type_, - layer_num=self.layer_num_, bias_names=self._router_bias_name, - name="moe_gate", + quant_method=self.get_quant_method("moe_gate"), tp_rank=0, tp_world_size=1, ) diff --git a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py index b58e58799a..2cbc6cf585 100644 --- a/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mistral_mtp/layer_weights/transformer_layer_weight.py @@ -7,6 +7,10 @@ def __init__(self, layer_num, data_type, network_config, quant_cfg=None): super().__init__(layer_num, data_type, network_config, quant_cfg) return + def _parse_config(self): + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + def _init_weight_names(self): self._gate_weight_name = f"mtp.layers.{self.layer_num_}.mlp.gate_proj.weight" self._gate_bias_name = None @@ -24,20 +28,20 @@ def _init_weight(self): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=[self._gate_bias_name, self._up_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) def _init_norm(self): diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index cc125f926d..fa20a63f9c 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -34,12 +34,12 @@ def _init_moe(self): split_inter_size = inter_size // self.tp_world_size_ self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_routed_experts], weight_names=self.moe_gate_weight_name, data_type=self.data_type_, bias_names=self.moe_gate_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="moe_gate", + quant_method=self.get_quant_method("moe_gate"), tp_rank=0, tp_world_size=1, # no tensor parallelism ) diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 41f24f79cb..4adbe4f5e6 100644 --- a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py @@ -51,18 +51,18 @@ def _init_weight_names(self): def _init_ffn(self): self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._gate_up_weight_name, data_type=self.data_type_, bias_names=self._gate_up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="gate_up_proj", + 
quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py index 53342e221f..7370c69530 100644 --- a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py @@ -28,18 +28,18 @@ def _init_weight_names(self): def _init_ffn(self): self.up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self._up_weight_name, data_type=self.data_type_, bias_names=self._up_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="up_proj", + quant_method=self.get_quant_method("up_proj"), ) self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self._down_weight_name, data_type=self.data_type_, bias_names=self._down_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="down_proj", + quant_method=self.get_quant_method("down_proj"), ) diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 5a7a24a9a8..8bcbe3358e 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -31,6 +31,9 @@ def _parse_config(self): self.qkv_bias = self.network_config_.get("qkv_bias", True) self.layer_norm_eps = self.network_config_.get("layer_norm_eps", 1e-6) self.norm_type = self.network_config_.get("norm_type", "layer_norm") + self.n_embed = self.network_config_["hidden_size"] + self.padding_hidden_size + mlp_ratio = self.network_config_.get("mlp_ratio", 4) + self.n_inter = self.network_config_.get("intermediate_size", int(self.n_embed * mlp_ratio)) def _init_weight_names(self): self._att_norm_weight_name = f"vision_model.encoder.layers.{self.layer_num_}.norm1.weight" @@ -83,41 +86,41 @@ def _init_weight(self): def _init_qkv(self): self.qkv_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_embed, self.n_embed, self.n_embed], weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], data_type=self.data_type_, bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="qkv_proj", + quant_method=self.get_quant_method("qkv_proj"), ) def _init_o(self): self.o_proj = COLMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_embed], weight_names=self._o_weight_name, data_type=self.data_type_, bias_names=self._o_bias_name, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="o_proj", + quant_method=self.get_quant_method("o_proj"), ) def _init_ffn(self): self.ffn_1_proj_ = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], weight_names=self.fc1_weight_name_, data_type=self.data_type_, bias_names=self.fc1_bias_name_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="ffn_1_proj", + quant_method=self.get_quant_method("ffn_1_proj"), ) self.ffn_2_proj_ = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], weight_names=self.fc2_weight_name_, data_type=self.data_type_, 
bias_names=self.fc2_bias_name_, - quant_cfg=self.quant_cfg, - layer_num=self.layer_num_, - name="ffn_2_proj", + quant_method=self.get_quant_method("ffn_2_proj"), ) def _init_norm(self): @@ -136,18 +139,17 @@ def _init_norm(self): bias_name=self._ffn_norm_bias_name, ) if self.qk_norm: - head_num = self.network_config_["num_attention_heads"] self.q_norm_weight_ = TpRMSNormWeight( dim=hidden_size, weight_name=self._q_norm_weight_name, data_type=self.data_type_, - head_num=head_num, + bias_name=None, ) self.k_norm_weight_ = TpRMSNormWeight( dim=hidden_size, weight_name=self._k_norm_weight_name, data_type=self.data_type_, - head_num=head_num, + bias_name=None, ) def load_hf_weights(self, weights): From 057f742767207918c1cdc27b1654c75cbacc649f Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 13 Jan 2026 05:28:30 +0000 Subject: [PATCH 19/43] remove torchao --- .../CN/source/tutorial/api_server_args_zh.rst | 19 +- .../EN/source/tutorial/api_server_args_zh.rst | 19 +- lightllm/common/quantization/__init__.py | 1 - lightllm/common/quantization/torchao_quant.py | 168 ------------------ lightllm/server/api_cli.py | 10 +- 5 files changed, 13 insertions(+), 204 deletions(-) delete mode 100644 lightllm/common/quantization/torchao_quant.py diff --git a/docs/CN/source/tutorial/api_server_args_zh.rst b/docs/CN/source/tutorial/api_server_args_zh.rst index 5976fcb322..3503ea9ca4 100755 --- a/docs/CN/source/tutorial/api_server_args_zh.rst +++ b/docs/CN/source/tutorial/api_server_args_zh.rst @@ -349,17 +349,14 @@ PD 分离模式参数 .. option:: --quant_type 量化方法,可选值: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``vllm-fp8w8a8-b128`` + * ``deepgemm-fp8w8a8-b128`` * ``triton-fp8w8a8-block128`` + * ``awq`` + * ``awq_marlin`` * ``none`` (默认) .. option:: --quant_cfg @@ -371,13 +368,7 @@ PD 分离模式参数 .. option:: --vit_quant_type ViT 量化方法,可选值: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``none`` (默认) diff --git a/docs/EN/source/tutorial/api_server_args_zh.rst b/docs/EN/source/tutorial/api_server_args_zh.rst index 0767ae7e3b..42555ff9ec 100755 --- a/docs/EN/source/tutorial/api_server_args_zh.rst +++ b/docs/EN/source/tutorial/api_server_args_zh.rst @@ -341,17 +341,14 @@ Quantization Parameters .. option:: --quant_type Quantization method, optional values: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``vllm-fp8w8a8-b128`` + * ``deepgemm-fp8w8a8-b128`` * ``triton-fp8w8a8-block128`` + * ``awq`` + * ``awq_marlin`` * ``none`` (default) .. option:: --quant_cfg @@ -363,13 +360,7 @@ Quantization Parameters .. 
option:: --vit_quant_type ViT quantization method, optional values: - - * ``ppl-w4a16-128`` - * ``flashllm-w6a16`` - * ``ao-int4wo-[32,64,128,256]`` - * ``ao-int8wo`` - * ``ao-fp8w8a16`` - * ``ao-fp6w6a16`` + * ``vllm-w8a8`` * ``vllm-fp8w8a8`` * ``none`` (default) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index ecf2e6d42f..af1327cd89 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -1,7 +1,6 @@ import yaml import collections from .registry import QUANTMETHODS -from .torchao_quant import * from .w8a8_quant import * from .triton_quant.triton_quant import * from .deepgemm_quant import * diff --git a/lightllm/common/quantization/torchao_quant.py b/lightllm/common/quantization/torchao_quant.py deleted file mode 100644 index d1db65b35a..0000000000 --- a/lightllm/common/quantization/torchao_quant.py +++ /dev/null @@ -1,168 +0,0 @@ -import os -import torch -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from typing import TYPE_CHECKING, Optional - -from .quantize_method import WeightPack - -try: - HAS_TORCH_AO = True - from torchao.quantization import ( - int4_weight_only, - int8_weight_only, - float8_weight_only, - fpx_weight_only, - int8_dynamic_activation_int8_weight, - float8_dynamic_activation_float8_weight, - quantize_, - ) - from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - ) -except: - HAS_TORCH_AO = False - - -class AOBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - assert HAS_TORCH_AO, "torchao is not installed, you can't use quant api of it" - assert TORCH_VERSION_AT_LEAST_2_4, "torchao requires torch >=2.4" - self.quant_func = None - - def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: - """ """ - dummy_linear = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - dummy_linear.weight = torch.nn.Parameter(weight.cuda(self.device_id_)) - quantize_(dummy_linear, self.quant_func) - return WeightPack(weight=dummy_linear.weight, weight_scale=None, weight_zero_point=None) - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - ) -> torch.Tensor: - weights = weight_pack.weight - bias = weight_pack.bias - return F.linear(input_tensor, weights, bias) - - @property - def method_name(self): - return "ao-base" - - -@QUANTMETHODS.register(["ao-w4a16-256"]) -class AOW4A16QuantizationMethodGroup256(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 256 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-256" - - -@QUANTMETHODS.register(["ao-w4a16-128"]) -class AOW4A16QuantizationMethodGroup128(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 128 - self.quant_func = int4_weight_only(group_size=self.group_size) - self.has_weight_scale = False - self.has_weight_zero_point = False - - @property - def method_name(self): - return "ao-w4a16-128" - - -@QUANTMETHODS.register(["ao-w4a16-64"]) -class AOW4A16QuantizationMethodGroup64(AOBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.group_size = 64 - self.quant_func = 
int4_weight_only(group_size=self.group_size)
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-w4a16-64"
-
-
-@QUANTMETHODS.register(["ao-w4a16-32"])
-class AOW4A16QuantizationMethodGroup32(AOBaseQuantizationMethod):
-    def __init__(self):
-        super().__init__()
-        self.group_size = 32
-        self.quant_func = int4_weight_only(group_size=self.group_size)
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-w4a16-32"
-
-
-@QUANTMETHODS.register("ao-w8a8")
-class AOW8A8QuantizationMethod(AOBaseQuantizationMethod):
-    def __init__(self):
-        super().__init__()
-        self.quant_func = int8_dynamic_activation_int8_weight()
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-w8a8"
-
-
-@QUANTMETHODS.register("ao-w8a16")
-class AOW8A16QuantizationMethod(AOBaseQuantizationMethod):
-    def __init__(self):
-        super().__init__()
-        self.quant_func = int8_weight_only()
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-w8a16"
-
-
-@QUANTMETHODS.register("ao-fp8w8a16")
-class AOFP8W8A16QuantizationMethod(AOBaseQuantizationMethod):
-    def __init__(self):
-        super().__init__()
-        is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-        assert is_cuda_8_9, "FP8 requires GPU with compute capability >= 8.9"
-        self.quant_func = float8_weight_only()
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-fp8w8a16"
-
-
-@QUANTMETHODS.register("ao-fp6w6a16")
-class AOFP6W6A16QuantizationMethod(AOBaseQuantizationMethod):
-    def __init__(self):
-        super().__init__()
-        assert TORCH_VERSION_AT_LEAST_2_5, "torchao fp6 requires torch >=2.5"
-        self.quant_func = fpx_weight_only(3, 2)
-        self.has_weight_scale = False
-        self.has_weight_zero_point = False
-
-    @property
-    def method_name(self):
-        return "ao-fp6w6a16"
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 9cbc59032c..821f94236b 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -465,10 +465,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--quant_type",
         type=str,
         default="none",
-        help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16
-        | ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16
-        | vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
-        | triton-fp8w8a8-block128""",
+        help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
+        | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin""",
     )
     parser.add_argument(
         "--quant_cfg",
@@ -481,9 +479,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--vit_quant_type",
         type=str,
         default="none",
-        help="""Quantization method: ppl-w4a16-128 | flashllm-w6a16
-        | ao-int4wo-[32,64,128,256] | ao-int8wo | ao-fp8w8a16 | ao-fp6w6a16
-        | vllm-w8a8 | vllm-fp8w8a8""",
+        help="""Quantization method for ViT: vllm-w8a8 | vllm-fp8w8a8""",
     )
     parser.add_argument(
         "--vit_quant_cfg",

From cee6e238b519f5ca2035c17eb596feb9e0a2a543 Mon Sep 17 00:00:00 2001
From: R0CKSTAR
Date: Tue, 13 Jan 2026 13:58:15 +0800
Subject: [PATCH 20/43] [MUSA] Add shell script to generate requirements-musa.txt and update doc (#1175)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Testing Done

Tested in a clean docker container without vllm installed.
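For context, the log below was produced by plain `POST /generate` requests against the server started above (listening on port 8000 per the launch args). A minimal client sketch follows; it is not part of this patch, and the use of the `requests` package plus the `inputs` / `parameters` / `max_new_tokens` payload fields are assumptions based on LightLLM's usual generate API, so adjust them if your build differs:

```python
import requests


def generate(prompt: str, max_new_tokens: int = 64) -> dict:
    # Hypothetical smoke-test client. The endpoint and port come from the
    # launch args / access log in this patch; the JSON field names are an
    # assumption about LightLLM's generate API, not something this PR defines.
    resp = requests.post(
        "http://127.0.0.1:8000/generate",
        json={"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}},
        timeout=120,
    )
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    # A short prompt like this matches the small prompt_token_num values in the log below.
    print(generate("What is AI?"))
```

Any comparable HTTP client works; only the endpoint path and port are taken from the captured output.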
```bash root@worker3218:/ws# python -m lightllm.server.api_server --model_dir /home/dist/Qwen3-0.6B/ --disable_cudagraph --host 0.0.0.0 WARNING 01-12 13:45:20 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:45:20 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:45:20 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:45:20 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-12 13:45:20 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. INFO 01-12 13:45:20 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-12 13:45:20 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm WARNING 01-12 13:45:20 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!! INFO 01-12 13:45:21 [shm_size_check.py:21] SHM check: Available=500.00 GB,Recommended=2.32 GB.Sufficient: True INFO 01-12 13:45:21 [api_start.py:94] zmq mode head: ipc:///tmp/_28765_0_ INFO 01-12 13:45:21 [api_start.py:96] use tgi api: False INFO 01-12 13:45:21 [api_start.py:219] alloced ports: [10017, 10004, 10209, 10223, 10297, 10257, 10068, 10179, 10206, 10285] INFO 01-12 13:45:21 [api_start.py:270] all start args:Namespace(run_mode='normal', host='0.0.0.0', port=8000, httpserver_workers=1, zmq_mode='ipc:///tmp/_28765_0_', pd_master_ip='0.0.0.0', pd_master_port=1212, pd_decode_rpyc_port=42000, select_p_d_node_strategy='round_robin', config_server_host=None, config_server_port=None, nixl_pd_kv_page_num=16, nixl_pd_kv_page_size=1024, model_name='default_model_name', model_dir='/home/dist/Qwen3-0.6B/', tokenizer_mode='fast', load_way='HF', max_total_token_num=None, mem_fraction=0.9, batch_max_tokens=8448, eos_id=[151645], tool_call_parser=None, reasoning_parser=None, chat_template=None, running_max_req_size=1000, nnodes=1, node_rank=0, multinode_httpmanager_port=12345, multinode_router_gloo_port=20001, tp=1, dp=1, dp_balancer='bs_balancer', max_req_total_len=16384, nccl_host='127.0.0.1', nccl_port=28765, use_config_server_to_init_nccl=False, trust_remote_code=False, disable_log_stats=False, log_stats_interval=10, disable_shm_warning=False, router_token_ratio=0.0, router_max_new_token_len=1024, router_max_wait_tokens=1, disable_aggressive_schedule=False, use_dynamic_prompt_cache=False, disable_dynamic_prompt_cache=False, chunked_prefill_size=4096, disable_chunked_prefill=False, diverse_mode=False, token_healing_mode=False, output_constraint_mode='none', first_token_constraint_mode=False, enable_multimodal=False, enable_multimodal_audio=False, enable_mps=False, disable_custom_allreduce=False, enable_custom_allgather=False, enable_tpsp_mix_mode=False, enable_dp_prefill_balance=False, enable_prefill_microbatch_overlap=False, enable_decode_microbatch_overlap=False, llm_prefill_att_backend=['triton'], llm_decode_att_backend=['triton'], llm_kv_type='None', llm_kv_quant_group_size=8, cache_capacity=200, embed_cache_storage_size=4, data_type='bfloat16', return_all_prompt_logprobs=False, use_reward_model=False, long_truncation_mode=None, use_tgi_api=False, health_monitor=False, metric_gateway=None, job_name='lightllm', grouping_key=[], push_interval=10, visual_infer_batch_size=1, visual_send_batch_size=1, visual_gpu_ids=[0], visual_tp=1, visual_dp=1, visual_nccl_ports=[29500], 
enable_monitor_auth=False, disable_cudagraph=True, enable_prefill_cudagraph=False, prefll_cudagraph_max_handle_token=512, graph_max_batch_size=256, graph_split_batch_size=32, graph_grow_step_size=16, graph_max_len_in_batch=16384, quant_type='none', quant_cfg=None, vit_quant_type='none', vit_quant_cfg=None, sampling_backend='triton', penalty_counter_mode='gpu_counter', ep_redundancy_expert_config_path=None, auto_update_redundancy_expert=False, enable_fused_shared_experts=False, mtp_mode=None, mtp_draft_model_dir=None, mtp_step=0, kv_quant_calibration_config_path=None, schedule_time_interval=0.03, enable_cpu_cache=False, cpu_cache_storage_size=2, cpu_cache_token_page_size=256, enable_disk_cache=False, disk_cache_storage_size=10, disk_cache_dir=None, enable_dp_prompt_cache_fetch=False, router_port=10017, detokenization_port=10004, http_server_port=10209, visual_port=10223, audio_port=10297, cache_port=10257, metric_port=10068, multi_level_kv_cache_port=10179, pd_node_infer_rpyc_ports=[10285], pd_node_id=288479957063433772586255832729030629155, pd_p_allowed_port_min=20000, pd_p_allowed_port_max=30000) WARNING 01-12 13:45:27 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:45:27 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:45:27 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:45:27 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-12 13:45:27 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. 2026-01-12 13:45:27 | server | 140078322902144 | INFO : server started on [0.0.0.0]:10068 INFO 01-12 13:45:27 [start_utils.py:37] init func start_metric_manager : init ok WARNING 01-12 13:45:33 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:45:33 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:45:33 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:45:33 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-12 13:45:33 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. INFO 01-12 13:45:33 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-12 13:45:33 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:45:33 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:45:33 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:45:33 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. INFO 01-12 13:45:33 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. 
INFO 01-12 13:45:33 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-12 13:45:33 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm WARNING 01-12 13:45:33 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm INFO 01-12 13:45:33 [manager.py:36] pub_to_httpserver sendhwm 1000 WARNING 01-12 13:45:33 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!! 2026-01-12 13:45:33 | server | 140078322902144 | INFO : accepted ('127.0.0.1', 47548) with fd 25 2026-01-12 13:45:33 | server | 140046992746048 | INFO : welcome ('127.0.0.1', 47548) INFO 01-12 13:45:38 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-12 13:45:38 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:45:38 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:45:38 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. WARNING 01-12 13:45:38 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:45:38 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm INFO 01-12 13:45:38 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. WARNING 01-12 13:45:40 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!! INFO 01-12 13:45:40 [model_rpc.py:67] Initialized RPC server for rank 0. INFO 01-12 13:45:40 [model_rpc.py:168] use ChunkedPrefillBackend INFO 01-12 13:45:43 [basemodel.py:169] Initial quantization. The default quantization method is none pid 45988 Loading model weights with 1 workers: 0%| | 0/1 [00:00 INFO 01-12 13:45:43 [mem_manager.py:99] 69.76169700622559 GB space is available after load the model weight INFO 01-12 13:45:43 [mem_manager.py:99] 0.109375 MB is the size of one token kv cache INFO 01-12 13:45:43 [mem_manager.py:99] 653128 is the profiled max_total_token_num with the mem_fraction 0.9 INFO 01-12 13:45:43 [mem_manager.py:99] INFO 01-12 13:45:44 [basemodel.py:126] use prefill att backend: TritonAttBackend INFO 01-12 13:45:44 [basemodel.py:127] use decode att backend: TritonAttBackend warming up: 0%| | 0/12 [00:00 INFO 01-12 13:46:40 [manager.py:194] use req queue ChunkedPrefillQueue INFO 01-12 13:46:40 [start_utils.py:37] init func start_router_process : init ok INFO 01-12 13:46:40 [start_utils.py:37] init func start_detokenization_process : init ok INFO 01-12 13:46:40 [api_start.py:58] start process pid 38328 INFO 01-12 13:46:40 [api_start.py:59] http server pid 5689 [2026-01-12 13:46:40 +0800] [5689] [INFO] Starting gunicorn 23.0.0 [2026-01-12 13:46:40 +0800] [5689] [INFO] Listening at: http://0.0.0.0:8000 (5689) [2026-01-12 13:46:40 +0800] [5689] [INFO] Using worker: uvicorn.workers.UvicornWorker [2026-01-12 13:46:40 +0800] [5690] [INFO] Booting worker with pid: 5690 WARNING 01-12 13:46:46 [sgl_utils.py:14] sgl_kernel is not installed, you can't use the api of it. You can solve it by running `pip install sgl_kernel`. WARNING 01-12 13:46:46 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3. Try to upgrade it. WARNING 01-12 13:46:46 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it. WARNING 01-12 13:46:46 [vllm_utils.py:18] vllm is not installed, you can't use the api of it. You can solve it by running `pip install vllm`. 
INFO 01-12 13:46:46 [communication_op.py:57] deep_ep is not installed, you can't use the api of it. INFO 01-12 13:46:46 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On WARNING 01-12 13:46:46 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm [2026-01-12 13:46:47 +0800] [5690] [INFO] Started server process [5690] [2026-01-12 13:46:47 +0800] [5690] [INFO] Waiting for application startup. INFO 01-12 13:46:47 [api_http.py:359] server start up 2026-01-12 13:46:47 | server | 140078322902144 | INFO : accepted ('127.0.0.1', 35962) with fd 26 2026-01-12 13:46:47 | server | 140046984353344 | INFO : welcome ('127.0.0.1', 35962) 2026-01-12 13:46:47 | server | 140078322902144 | INFO : accepted ('127.0.0.1', 35966) with fd 27 2026-01-12 13:46:47 | server | 140046975960640 | INFO : welcome ('127.0.0.1', 35966) INFO 01-12 13:46:48 [req_id_generator.py:34] ReqIDGenerator init finished INFO 01-12 13:46:48 [api_http.py:363] server start up ok, loop use is [2026-01-12 13:46:48 +0800] [5690] [INFO] Application startup complete. DEBUG 01-12 13:47:52 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:47:52 [manager.py:283] DEBUG 01-12 13:47:52 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:47:52 [manager.py:284] [2026-01-12 13:48:13 +0800] [5689] [INFO] Handling signal: winch [2026-01-12 13:48:13 +0800] [5689] [INFO] Handling signal: winch [2026-01-12 13:48:13 +0800] [5689] [INFO] Handling signal: winch [2026-01-12 13:48:13 +0800] [5689] [INFO] Handling signal: winch DEBUG 01-12 13:48:55 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:48:55 [manager.py:283] DEBUG 01-12 13:48:55 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:48:55 [manager.py:284] DEBUG 01-12 13:49:58 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:49:58 [manager.py:283] DEBUG 01-12 13:49:58 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:49:58 [manager.py:284] DEBUG 01-12 13:51:02 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:02 [manager.py:283] DEBUG 01-12 13:51:02 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:51:02 [manager.py:284] INFO 01-12 13:51:09 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:09 lightllm_req_id:8 INFO 01-12 13:51:09 [manager.py:422] router recive req id 8 cost time 0.05662369728088379 s DEBUG 01-12 13:51:09 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197069.7485027s req_ids:[8] DEBUG 01-12 13:51:09 [manager.py:320] INFO 01-12 13:51:09 [manager.py:55] detokenization recv req id 8 cost time 0.07959198951721191 s DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 estimated_peak_token_count: 39 DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 token used ratio: 6.12437378278071e-06 not contain prompt cache tree unrefed token DEBUG 01-12 13:51:11 [manager.py:251] dp_i 0 token used ratio: 6.12437378278071e-06 contain prompt cache tree unrefed token DEBUG 01-12 13:51:14 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:14 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:14 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:14 [manager.py:251] dp_i 0 estimated_peak_token_count: 39 DEBUG 01-12 13:51:14 [manager.py:251] dp_i 0 token used ratio: 7.655467228475888e-06 not contain prompt cache tree unrefed token DEBUG 01-12 
13:51:14 [manager.py:251] dp_i 0 token used ratio: 7.655467228475888e-06 contain prompt cache tree unrefed token INFO 01-12 13:51:16 [manager.py:163] detoken release req id 8 INFO 01-12 13:51:16 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:09 lightllm_req_id:8 first_token_cost:6353.325128555298ms total_cost_time:6671.096563339233ms,out_token_counter:17 mean_per_token_cost_time: 18.692437340231503ms prompt_token_num:4 gpu cache hit: False gpu_prompt_cache_len:0 gpu_prompt_cache_ratio:0.0 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:55472 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 13:51:16 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 13:51:16 [infer_batch.py:172] free a batch state: DEBUG 01-12 13:51:16 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 13:51:16 [infer_batch.py:172] radix hold token num 21 DEBUG 01-12 13:51:16 [infer_batch.py:172] mem manager can alloc token num 653107 DEBUG 01-12 13:51:16 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 13:51:16 [batch.py:56] router release req id 8 INFO 01-12 13:51:16 [shm_req_manager.py:111] all shm req has been release ok INFO 01-12 13:51:19 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:19 lightllm_req_id:16 INFO 01-12 13:51:19 [manager.py:422] router recive req id 16 cost time 0.019651412963867188 s DEBUG 01-12 13:51:19 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197079.421846s req_ids:[16] DEBUG 01-12 13:51:19 [manager.py:320] INFO 01-12 13:51:19 [manager.py:55] detokenization recv req id 16 cost time 0.021979331970214844 s INFO 01-12 13:51:19 [manager.py:163] detoken release req id 16 INFO 01-12 13:51:19 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:19 lightllm_req_id:16 first_token_cost:102.96440124511719ms total_cost_time:407.08088874816895ms,out_token_counter:17 mean_per_token_cost_time: 17.88920514723834ms prompt_token_num:4 gpu cache hit: True gpu_prompt_cache_len:3 gpu_prompt_cache_ratio:0.75 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:47146 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 13:51:19 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 13:51:19 [infer_batch.py:172] free a batch state: DEBUG 01-12 13:51:19 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 13:51:19 [infer_batch.py:172] radix hold token num 35 DEBUG 01-12 13:51:19 [infer_batch.py:172] mem manager can alloc token num 653093 DEBUG 01-12 13:51:19 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 13:51:19 [batch.py:56] router release req id 16 INFO 01-12 13:51:19 [shm_req_manager.py:111] all shm req has been release ok INFO 01-12 13:51:22 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:22 lightllm_req_id:24 INFO 01-12 13:51:22 [manager.py:422] router recive req id 24 cost time 0.015377998352050781 s DEBUG 01-12 13:51:22 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197082.1040523s req_ids:[24] DEBUG 01-12 13:51:22 [manager.py:320] INFO 01-12 13:51:22 [manager.py:55] detokenization recv req id 24 cost time 0.016767501831054688 s INFO 01-12 13:51:22 [manager.py:163] detoken release req id 24 INFO 01-12 13:51:22 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 
13:51:22 lightllm_req_id:24 first_token_cost:86.02452278137207ms total_cost_time:432.842493057251ms,out_token_counter:17 mean_per_token_cost_time: 20.4010570750517ms prompt_token_num:4 gpu cache hit: True gpu_prompt_cache_len:3 gpu_prompt_cache_ratio:0.75 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:47156 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 13:51:22 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 13:51:22 [infer_batch.py:172] free a batch state: DEBUG 01-12 13:51:22 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 13:51:22 [infer_batch.py:172] radix hold token num 51 DEBUG 01-12 13:51:22 [infer_batch.py:172] mem manager can alloc token num 653077 DEBUG 01-12 13:51:22 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 13:51:22 [batch.py:56] router release req id 24 INFO 01-12 13:51:22 [shm_req_manager.py:111] all shm req has been release ok INFO 01-12 13:51:26 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:26 lightllm_req_id:32 INFO 01-12 13:51:26 [manager.py:422] router recive req id 32 cost time 0.008630990982055664 s DEBUG 01-12 13:51:26 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197086.9206343s req_ids:[32] DEBUG 01-12 13:51:26 [manager.py:320] INFO 01-12 13:51:26 [manager.py:55] detokenization recv req id 32 cost time 0.011269092559814453 s INFO 01-12 13:51:27 [manager.py:163] detoken release req id 32 INFO 01-12 13:51:27 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:26 lightllm_req_id:32 first_token_cost:74.12481307983398ms total_cost_time:378.31759452819824ms,out_token_counter:17 mean_per_token_cost_time: 17.89369302637437ms prompt_token_num:4 gpu cache hit: True gpu_prompt_cache_len:3 gpu_prompt_cache_ratio:0.75 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:47160 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 13:51:27 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 13:51:27 [infer_batch.py:172] free a batch state: DEBUG 01-12 13:51:27 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 13:51:27 [infer_batch.py:172] radix hold token num 68 DEBUG 01-12 13:51:27 [infer_batch.py:172] mem manager can alloc token num 653060 DEBUG 01-12 13:51:27 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 13:51:27 [batch.py:56] router release req id 32 INFO 01-12 13:51:27 [shm_req_manager.py:111] all shm req has been release ok INFO 01-12 13:51:44 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:44 lightllm_req_id:40 INFO 01-12 13:51:44 [manager.py:422] router recive req id 40 cost time 0.009232759475708008 s DEBUG 01-12 13:51:44 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197104.2886696s req_ids:[40] DEBUG 01-12 13:51:44 [manager.py:320] INFO 01-12 13:51:44 [manager.py:55] detokenization recv req id 40 cost time 0.010197639465332031 s DEBUG 01-12 13:51:47 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:47 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:47 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:47 [manager.py:251] dp_i 0 estimated_peak_token_count: 2022 DEBUG 01-12 13:51:47 [manager.py:251] dp_i 0 token used ratio: 0.00019597996104898273 not contain prompt cache tree unrefed token DEBUG 01-12 
13:51:47 [manager.py:251] dp_i 0 token used ratio: 0.0002955010350191693 contain prompt cache tree unrefed token DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 estimated_peak_token_count: 2022 DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 token used ratio: 0.0002618169792138754 not contain prompt cache tree unrefed token DEBUG 01-12 13:51:50 [manager.py:251] dp_i 0 token used ratio: 0.0003613380531840619 contain prompt cache tree unrefed token DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 token used ratio: 0.0005052608370794086 not contain prompt cache tree unrefed token DEBUG 01-12 13:51:53 [manager.py:251] dp_i 0 token used ratio: 0.0006047819110495952 contain prompt cache tree unrefed token DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 token used ratio: 0.0007456425080535515 not contain prompt cache tree unrefed token DEBUG 01-12 13:51:56 [manager.py:251] dp_i 0 token used ratio: 0.000845163582023738 contain prompt cache tree unrefed token DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 token used ratio: 0.0009875552724733895 not contain prompt cache tree unrefed token DEBUG 01-12 13:51:59 [manager.py:251] dp_i 0 token used ratio: 0.001087076346443576 contain prompt cache tree unrefed token DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 token used ratio: 0.0012264058500018372 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:02 [manager.py:251] dp_i 0 token used ratio: 0.001325926923972024 contain prompt cache tree unrefed token DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 token used ratio: 0.0014086059700395635 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:05 [manager.py:251] dp_i 0 token used ratio: 0.00150812704400975 contain prompt cache tree unrefed token DEBUG 01-12 13:52:08 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:08 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:08 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:08 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 
13:52:08 [manager.py:251] dp_i 0 token used ratio: 0.0015724329687289474 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:08 [manager.py:251] dp_i 0 token used ratio: 0.001671954042699134 contain prompt cache tree unrefed token DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 token used ratio: 0.0017331977805269412 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:11 [manager.py:251] dp_i 0 token used ratio: 0.0018327188544971277 contain prompt cache tree unrefed token DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 token used ratio: 0.0018939625923249349 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:14 [manager.py:251] dp_i 0 token used ratio: 0.0019934836662951214 contain prompt cache tree unrefed token DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 token used ratio: 0.0020531963106772333 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:17 [manager.py:251] dp_i 0 token used ratio: 0.00215271738464742 contain prompt cache tree unrefed token DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 token used ratio: 0.002213961122475227 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:20 [manager.py:251] dp_i 0 token used ratio: 0.0023134821964454133 contain prompt cache tree unrefed token DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 token used ratio: 0.0023731948408275256 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:23 [manager.py:251] dp_i 0 token used ratio: 0.002472715914797712 contain prompt cache tree unrefed token DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 token used ratio: 0.002509462157494396 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:26 [manager.py:251] dp_i 0 token used ratio: 0.002608983231464583 contain prompt cache tree unrefed token DEBUG 01-12 13:52:29 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:29 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:29 
[manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:29 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:29 [manager.py:251] dp_i 0 token used ratio: 0.0026288874462586202 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:29 [manager.py:251] dp_i 0 token used ratio: 0.0027284085202288065 contain prompt cache tree unrefed token DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 token used ratio: 0.002746781641577149 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:32 [manager.py:251] dp_i 0 token used ratio: 0.002846302715547335 contain prompt cache tree unrefed token DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 token used ratio: 0.002861613650004287 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:35 [manager.py:251] dp_i 0 token used ratio: 0.0029611347239744735 contain prompt cache tree unrefed token DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 token used ratio: 0.002939699415734741 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:38 [manager.py:251] dp_i 0 token used ratio: 0.0030392204897049277 contain prompt cache tree unrefed token DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 estimated_peak_token_count: 2020 DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 token used ratio: 0.0030116608076824146 not contain prompt cache tree unrefed token DEBUG 01-12 13:52:41 [manager.py:251] dp_i 0 token used ratio: 0.003111181881652601 contain prompt cache tree unrefed token INFO 01-12 13:52:42 [manager.py:163] detoken release req id 40 INFO 01-12 13:52:42 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 13:51:44 lightllm_req_id:40 first_token_cost:91.23969078063965ms total_cost_time:58654.03771400452ms,out_token_counter:2000 mean_per_token_cost_time: 29.28139901161194ms prompt_token_num:4 gpu cache hit: True gpu_prompt_cache_len:3 gpu_prompt_cache_ratio:0.75 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:50156 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 13:52:42 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 13:52:42 [infer_batch.py:172] free a batch state: DEBUG 01-12 13:52:42 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 13:52:42 [infer_batch.py:172] radix hold token num 2068 DEBUG 01-12 13:52:42 [infer_batch.py:172] mem manager can alloc token num 651060 DEBUG 01-12 13:52:42 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 13:52:42 
[batch.py:56] router release req id 40 INFO 01-12 13:52:42 [shm_req_manager.py:111] all shm req has been release ok DEBUG 01-12 13:52:50 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:52:50 [manager.py:283] DEBUG 01-12 13:52:50 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:52:50 [manager.py:284] DEBUG 01-12 13:53:53 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:53:53 [manager.py:283] DEBUG 01-12 13:53:53 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:53:53 [manager.py:284] DEBUG 01-12 13:54:56 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:54:56 [manager.py:283] DEBUG 01-12 13:54:56 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:54:56 [manager.py:284] DEBUG 01-12 13:56:00 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:56:00 [manager.py:283] DEBUG 01-12 13:56:00 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:56:00 [manager.py:284] DEBUG 01-12 13:57:03 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:57:03 [manager.py:283] DEBUG 01-12 13:57:03 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:57:03 [manager.py:284] DEBUG 01-12 13:58:06 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:58:06 [manager.py:283] DEBUG 01-12 13:58:06 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:58:06 [manager.py:284] DEBUG 01-12 13:59:09 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 13:59:09 [manager.py:283] DEBUG 01-12 13:59:09 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 13:59:09 [manager.py:284] INFO 01-12 14:00:06 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-12 14:00:06 lightllm_req_id:48 INFO 01-12 14:00:06 [manager.py:422] router recive req id 48 cost time 0.00828862190246582 s DEBUG 01-12 14:00:06 [manager.py:320] Prefill Batch: batch_id=-1, time:1768197606.2045314s req_ids:[48] DEBUG 01-12 14:00:06 [manager.py:320] INFO 01-12 14:00:06 [manager.py:55] detokenization recv req id 48 cost time 0.010654926300048828 s DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 estimated_peak_token_count: 222 DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 token used ratio: 4.746389681655051e-05 not contain prompt cache tree unrefed token DEBUG 01-12 14:00:06 [manager.py:251] dp_i 0 token used ratio: 0.0032091718621770926 contain prompt cache tree unrefed token DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 current batch size: 1 DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 paused req num: 0 DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 frozen token num: 0 DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 estimated_peak_token_count: 222 DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 token used ratio: 0.0002878455677906934 not contain prompt cache tree unrefed token DEBUG 01-12 14:00:09 [manager.py:251] dp_i 0 token used ratio: 0.003449553533151235 contain prompt cache tree unrefed token INFO 01-12 14:00:10 [manager.py:163] detoken release req id 48 INFO 01-12 14:00:10 [manager.py:614] X-Request-Id: X-Session-Id: start_time:2026-01-12 14:00:06 lightllm_req_id:48 first_token_cost:94.14434432983398ms total_cost_time:3917.818784713745ms,out_token_counter:200 mean_per_token_cost_time: 19.118372201919556ms prompt_token_num:4 gpu cache hit: True gpu_prompt_cache_len:3 gpu_prompt_cache_ratio:0.75 cpu 
cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 127.0.0.1:53836 - "POST /generate HTTP/1.1" 200 DEBUG 01-12 14:00:10 [req_manager.py:78] freed all request size 1008 DEBUG 01-12 14:00:10 [infer_batch.py:172] free a batch state: DEBUG 01-12 14:00:10 [infer_batch.py:172] radix refed token num 0 DEBUG 01-12 14:00:10 [infer_batch.py:172] radix hold token num 2266 DEBUG 01-12 14:00:10 [infer_batch.py:172] mem manager can alloc token num 650862 DEBUG 01-12 14:00:10 [infer_batch.py:172] mem manager total size 653128 INFO 01-12 14:00:10 [batch.py:56] router release req id 48 INFO 01-12 14:00:10 [shm_req_manager.py:111] all shm req has been release ok DEBUG 01-12 14:00:12 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 14:00:12 [manager.py:283] DEBUG 01-12 14:00:12 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 14:00:12 [manager.py:284] DEBUG 01-12 14:11:47 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 14:11:47 [manager.py:283] DEBUG 01-12 14:11:47 [manager.py:284] dp_i 0
estimated_peak_token_count: 0 DEBUG 01-12 14:11:47 [manager.py:284] [2026-01-12 14:11:57 +0800] [5689] [INFO] Handling signal: winch DEBUG 01-12 14:12:51 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 14:12:51 [manager.py:283] DEBUG 01-12 14:12:51 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 14:12:51 [manager.py:284] DEBUG 01-12 14:13:54 [manager.py:283] dp_i 0 frozen token num: 0 DEBUG 01-12 14:13:54 [manager.py:283] DEBUG 01-12 14:13:54 [manager.py:284] dp_i 0 estimated_peak_token_count: 0 DEBUG 01-12 14:13:54 [manager.py:284] ``` Signed-off-by: Xiaodong Ye --- .gitignore | 1 + .../source/getting_started/installation.rst | 12 +- .../source/getting_started/installation.rst | 22 ++-- generate_requirements_musa.sh | 105 ++++++++++++++++++ 4 files changed, 127 insertions(+), 13 deletions(-) create mode 100755 generate_requirements_musa.sh diff --git a/.gitignore b/.gitignore index 6049c2cdbe..63408699f4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist .idea .vscode tmp/ +requirements-musa.txt diff --git a/docs/CN/source/getting_started/installation.rst b/docs/CN/source/getting_started/installation.rst index fb998b7567..5fa0e304d2 100755 --- a/docs/CN/source/getting_started/installation.rst +++ b/docs/CN/source/getting_started/installation.rst @@ -27,7 +27,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 前请确保你的docker设置中已经分配了足够的共享内存,否则可能导致 $ # 服务无法正常启动。 $ # 1.如果是纯文本服务,建议分配2GB以上的共享内存, 如果你的内存充足,建议分配16GB以上的共享内存. - $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. + $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. $ # 如果你没有足够的共享内存,可以尝试在启动服务的时候调低 --running_max_req_size 参数,这会降低 $ # 服务的并发请求数量,但可以减少共享内存的占用。如果是多模态服务,也可以通过降低 --cache_capacity $ # 参数来减少共享内存的占用。 @@ -38,7 +38,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 你也可以使用源码手动构建镜像并运行,建议手动构建镜像,因为更新比较频繁: .. code-block:: console - + $ # 进入代码仓库的根目录 $ cd /lightllm $ # 手动构建镜像, docker 目录下有不同功能场景的镜像构建文件,按需构建。 @@ -52,7 +52,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 或者你也可以直接使用脚本一键启动镜像并且运行: .. code-block:: console - + $ # 查看脚本参数 $ python tools/quick_launch_docker.py --help @@ -80,6 +80,10 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 安装lightllm的依赖 (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # 安装lightllm的依赖 (摩尔线程 GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # 安装lightllm $ python setup.py install @@ -97,6 +101,6 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + 具体原因可以参考:`issue `_ 和 `fix PR `_ diff --git a/docs/EN/source/getting_started/installation.rst b/docs/EN/source/getting_started/installation.rst index 75fa714764..6439c48de3 100755 --- a/docs/EN/source/getting_started/installation.rst +++ b/docs/EN/source/getting_started/installation.rst @@ -24,16 +24,16 @@ The easiest way to install Lightllm is using the official image. You can directl $ docker pull ghcr.io/modeltc/lightllm:main $ $ # Run,The current LightLLM service relies heavily on shared memory. - $ # Before starting, please make sure that you have allocated enough shared memory + $ # Before starting, please make sure that you have allocated enough shared memory $ # in your Docker settings; otherwise, the service may fail to start properly. $ # - $ # 1. For text-only services, it is recommended to allocate more than 2GB of shared memory. + $ # 1. 
For text-only services, it is recommended to allocate more than 2GB of shared memory. $ # If your system has sufficient RAM, allocating 16GB or more is recommended. - $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. + $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. $ # You can adjust this value according to your specific requirements. $ # - $ # If you do not have enough shared memory available, you can try lowering - $ # the --running_max_req_size parameter when starting the service. + $ # If you do not have enough shared memory available, you can try lowering + $ # the --running_max_req_size parameter when starting the service. $ # This will reduce the number of concurrent requests, but also decrease shared memory usage. $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ @@ -42,13 +42,13 @@ The easiest way to install Lightllm is using the official image. You can directl You can also manually build the image from source and run it: .. code-block:: console - + $ # move into lightllm root dir $ cd /lightllm $ # Manually build the image $ docker build -t -f ./docker/Dockerfile . $ - $ # Run, + $ # Run, $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ $ /bin/bash @@ -56,7 +56,7 @@ You can also manually build the image from source and run it: Or you can directly use the script to launch the image and run it with one click: .. code-block:: console - + $ # View script parameters $ python tools/quick_launch_docker.py --help @@ -84,6 +84,10 @@ You can also install Lightllm from source: $ # Install Lightllm dependencies (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # Install Lightllm dependencies (Moore Threads GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # Install Lightllm $ python setup.py install @@ -101,5 +105,5 @@ You can also install Lightllm from source: .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + For specific reasons, please refer to: `issue `_ and `fix PR `_ \ No newline at end of file diff --git a/generate_requirements_musa.sh b/generate_requirements_musa.sh new file mode 100755 index 0000000000..f5bfb8ff83 --- /dev/null +++ b/generate_requirements_musa.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Script to generate requirements-musa.txt from requirements.txt +# MUSA is not compatible with CUDA packages, so they need to be removed +# Torch-related packages are pre-installed in the MUSA docker container + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +INPUT_FILE="${SCRIPT_DIR}/requirements.txt" +OUTPUT_FILE="${SCRIPT_DIR}/requirements-musa.txt" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: requirements.txt not found at $INPUT_FILE" + exit 1 +fi + +echo "Generating requirements-musa.txt from requirements.txt..." 
+ +# Define patterns to remove (CUDA-specific packages) +# These packages are not compatible with MUSA +CUDA_PACKAGES=( + "^cupy" # cupy-cuda12x and similar + "^cuda_bindings" # CUDA bindings + "^nixl" # NIXL (NVIDIA Inter-node eXchange Library) + "^flashinfer" # flashinfer-python (CUDA-specific attention kernel) + "^sgl-kernel" # SGL kernel (CUDA-specific) +) + +# Define torch-related packages (pre-installed in MUSA container, remove version pins) +TORCH_PACKAGES=( + "^torch==" + "^torchvision==" +) + +# Create the output file with a header comment +cat > "$OUTPUT_FILE" << 'EOF' +# Requirements for MUSA (Moore Threads GPU) +# Auto-generated from requirements.txt by generate_requirements_musa.sh +# CUDA-specific packages have been removed +# Torch-related packages have version pins removed (pre-installed in MUSA container) + +EOF + +# Process the requirements file +while IFS= read -r line || [ -n "$line" ]; do + # Skip empty lines and comments (but keep them in output) + if [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]]; then + echo "$line" >> "$OUTPUT_FILE" + continue + fi + + # Extract package name (before ==, >=, <=, ~=, etc.) + pkg_name=$(echo "$line" | sed -E 's/^([a-zA-Z0-9_-]+).*/\1/') + + # Check if this is a CUDA package to skip + skip=false + for pattern in "${CUDA_PACKAGES[@]}"; do + if [[ "$pkg_name" =~ $pattern ]]; then + echo " Removing CUDA package: $line" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Check if this is a torch-related package (remove version pin) + for pattern in "${TORCH_PACKAGES[@]}"; do + if [[ "$line" =~ $pattern ]]; then + # Remove version pin, keep just the package name + pkg_only=$(echo "$line" | sed -E 's/==.*//') + echo " Unpinning version for: $pkg_only (pre-installed in MUSA container)" + echo "$pkg_only" >> "$OUTPUT_FILE" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Keep the package as-is + echo "$line" >> "$OUTPUT_FILE" + +done < "$INPUT_FILE" + +# Add MUSA-specific packages at the end +cat >> "$OUTPUT_FILE" << 'EOF' + +# MUSA-specific packages +torch_musa +torchada +EOF + +echo "" +echo "Successfully generated: $OUTPUT_FILE" +echo "" +echo "Summary of changes:" +echo " - Removed CUDA-specific packages: cupy-cuda12x, cuda_bindings, nixl, flashinfer-python, sgl-kernel" +echo " - Unpinned torch-related packages: torch, torchvision (pre-installed in MUSA container)" +echo " - Added MUSA-specific packages: torch_musa, torchada" + From e2b3305dc54a0c228259f089a80d8bab71a5d7b3 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 13 Jan 2026 12:20:53 +0000 Subject: [PATCH 21/43] quantization draft --- lightllm/common/quantization/__init__.py | 41 ++- lightllm/common/quantization/backend.py | 82 +++++ .../common/quantization/deepgemm_quant.py | 136 ------- .../quantization/triton_quant/triton_quant.py | 112 ------ .../common/quantization/types/__init__.py | 13 + .../{awq_quant.py => types/awq.py} | 332 ++++++++++-------- .../common/quantization/types/fp8_block128.py | 216 ++++++++++++ .../quantization/types/fp8_per_token.py | 172 +++++++++ .../quantization/{ => types}/no_quant.py | 7 +- lightllm/common/quantization/types/w8a8.py | 108 ++++++ lightllm/common/quantization/w8a8_quant.py | 253 ------------- 11 files changed, 801 insertions(+), 671 deletions(-) create mode 100644 lightllm/common/quantization/backend.py delete mode 100644 lightllm/common/quantization/deepgemm_quant.py delete mode 100644 lightllm/common/quantization/triton_quant/triton_quant.py create mode 100644 
lightllm/common/quantization/types/__init__.py rename lightllm/common/quantization/{awq_quant.py => types/awq.py} (62%) create mode 100644 lightllm/common/quantization/types/fp8_block128.py create mode 100644 lightllm/common/quantization/types/fp8_per_token.py rename lightllm/common/quantization/{ => types}/no_quant.py (90%) create mode 100644 lightllm/common/quantization/types/w8a8.py delete mode 100644 lightllm/common/quantization/w8a8_quant.py diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index af1327cd89..d5289298cf 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -1,13 +1,21 @@ import yaml import collections from .registry import QUANTMETHODS -from .w8a8_quant import * -from .triton_quant.triton_quant import * -from .deepgemm_quant import * -from .awq_quant import * -from .no_quant import * +from .backend import QUANT_BACKEND from lightllm.utils.log_utils import init_logger +# Import all type classes (they auto-register with QUANTMETHODS) +from .types import ( + NoQuantization, + FP8Block128Quantization, + FP8PerTokenQuantization, + W8A8Quantization, + AWQQuantization, +) + +# Re-export for backwards compatibility +from .types.awq import is_awq_marlin_compatible + logger = init_logger(__name__) @@ -37,20 +45,21 @@ def _mapping_quant_method(self): if self.hf_quantization_method == "fp8": block_size = self.hf_quantization_config.get("weight_block_size", None) if block_size == [128, 128]: - from lightllm.common.quantization.deepgemm_quant import HAS_DEEPGEMM - - if HAS_DEEPGEMM: - self.quant_type = "deepgemm-fp8w8a8-b128" - else: - self.quant_type = "vllm-fp8w8a8-b128" - logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") + self.quant_type = "fp8-block128" + logger.info( + f"Selected quant type: fp8-block128, backend: {QUANT_BACKEND.get_backend('fp8-block128').name}" + ) + else: + self.quant_type = "fp8-per-token" + logger.info( + f"Selected quant type: fp8-per-token, backend: {QUANT_BACKEND.get_backend('fp8-per-token').name}" + ) elif self.hf_quantization_method == "awq": self.quant_type = "awq" - if is_awq_marlin_compatible(self.hf_quantization_config): - self.quant_type = "awq_marlin" - logger.info(f"select awq quant way: {self.quant_type}") + logger.info("Selected quant type: awq (marlin auto-selected if compatible)") else: - # TODO: more quant method + # TODO: more quant methods + raise NotImplementedError(f"Quant method {self.hf_quantization_method} not implemented yet.") pass def _parse_custom_cfg(self, custom_cfg_path): diff --git a/lightllm/common/quantization/backend.py b/lightllm/common/quantization/backend.py new file mode 100644 index 0000000000..e6d081ec27 --- /dev/null +++ b/lightllm/common/quantization/backend.py @@ -0,0 +1,82 @@ +import os +from enum import Enum, auto +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class BackendType(Enum): + TRITON = auto() + VLLM = auto() + DEEPGEMM = auto() + + +class BackendRegistry: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._initialized = True + + self._force_triton = os.getenv("LIGHTLLM_USE_TRITON_QUANT", "0").upper() in ["1", "TRUE", "ON"] + + self._has_vllm = self._check_vllm() + self._has_deepgemm = self._check_deepgemm() + + if self._force_triton: + 
logger.info("LIGHTLLM_USE_TRITON_QUANT is set, forcing Triton backend for quantization") + else: + logger.info(f"Available quantization backends: vLLM={self._has_vllm}, DeepGEMM={self._has_deepgemm}") + + def _check_vllm(self) -> bool: + try: + from lightllm.utils.vllm_utils import HAS_VLLM + + return HAS_VLLM + except ImportError: + return False + + def _check_deepgemm(self) -> bool: + try: + import deep_gemm # noqa: F401 + + return True + except ImportError: + return False + + @property + def force_triton(self) -> bool: + return self._force_triton + + @property + def has_vllm(self) -> bool: + return self._has_vllm + + @property + def has_deepgemm(self) -> bool: + return self._has_deepgemm + + def get_backend(self, quant_type: str) -> BackendType: + if self._force_triton: + return BackendType.TRITON + + if quant_type == "fp8-block128": + if self._has_deepgemm: + return BackendType.DEEPGEMM + elif self._has_vllm: + return BackendType.VLLM + elif quant_type in ["w8a8", "fp8-per-token"]: + if self._has_vllm: + return BackendType.VLLM + + return BackendType.TRITON + + +QUANT_BACKEND = BackendRegistry() diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py deleted file mode 100644 index 86dd9b5729..0000000000 --- a/lightllm/common/quantization/deepgemm_quant.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import torch -from torch.types import Device -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( - per_token_group_quant_fp8, - tma_align_input_scale, -) -from typing import TYPE_CHECKING, Optional - -from .quantize_method import WeightPack - -try: - HAS_DEEPGEMM = True - import deep_gemm -except: - HAS_DEEPGEMM = False - - -class DeepGEMMBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager - assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): - raise NotImplementedError("Not implemented") - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - raise NotImplementedError("Not implemented") - - @property - def method_name(self): - return "deepgemm-base" - - -@QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"]) -class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.block_size = 128 - self.weight_suffix = None - self.weight_zero_point_suffix = None - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - @property - def method_name(self): - return "deepgemm-fp8w8a8-b128" - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): - from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - - device = output.weight.device - weight, scale = weight_quant(weight.cuda(device), self.block_size) - output.weight[offset : offset + weight.shape[0], :].copy_(weight) - output.weight_scale[offset // self.block_size : offset + weight.shape[0] 
// self.block_size].copy_(scale) - return - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: "WeightPack", - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - input_scale = None - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - m, k = input_tensor.shape - n = qweight.shape[0] - if input_scale is None: - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, - self.block_size, - dtype=qweight.dtype, - column_major_scales=True, - scale_tma_aligned=True, - alloc_func=alloc_func, - ) - - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) - return out - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty( - expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) - - def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) - return - - def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[ - start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size - ].copy_(weight_scale) - return - - def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_zero_point[ - start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size - ].copy_(weight_zero_point) - return - - -def _deepgemm_fp8_nt(a_tuple, b_tuple, out): - if HAS_DEEPGEMM: - if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): - return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) - if hasattr(deep_gemm, "fp8_gemm_nt"): - return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) - raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") diff --git a/lightllm/common/quantization/triton_quant/triton_quant.py b/lightllm/common/quantization/triton_quant/triton_quant.py deleted file mode 100644 index 9f6a7bee27..0000000000 --- a/lightllm/common/quantization/triton_quant/triton_quant.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import torch -import torch.nn.functional as F -from lightllm.common.quantization.quantize_method import QuantizationMethod -from lightllm.common.quantization.registry import QUANTMETHODS -from .fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul -from .fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from typing import TYPE_CHECKING, Optional - -from lightllm.common.quantization.quantize_method import WeightPack - - -class TritonBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - 
self.cache_manager = g_cache_manager - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> WeightPack: - raise NotImplementedError("Not implemented") - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - raise NotImplementedError("Not implemented") - - -@QUANTMETHODS.register(["triton-fp8w8a8-block128"]) -class TritonFP8w8a8QuantizationMethod(TritonBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.is_moe = False - self.block_size = 128 - self.weight_suffix = None - self.weight_zero_point_suffix = None - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - # TODO block-wise quant kernel - raise NotImplementedError("Not implemented") - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - input_scale = None - m, k = input_tensor.shape - n = qweight.shape[1] - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - if input_scale is None: - input_tensor_q, input_scale = per_token_group_quant_fp8( - input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func - ) - else: - # TODO - raise "statci input scale is not supported by triton fp8 block gemm kernel." 
- m = input_tensor.shape[0] - n = qweight.shape[1] - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - w8a8_block_fp8_matmul( - input_tensor_q, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - return out - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty( - expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) - - def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) - return - - def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[ - start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size - ].copy_(weight_scale) - return - - def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_zero_point[ - start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size - ].copy_(weight_zero_point) - return diff --git a/lightllm/common/quantization/types/__init__.py b/lightllm/common/quantization/types/__init__.py new file mode 100644 index 0000000000..8cbcc2e684 --- /dev/null +++ b/lightllm/common/quantization/types/__init__.py @@ -0,0 +1,13 @@ +from .no_quant import NoQuantization +from .fp8_block128 import FP8Block128Quantization +from .fp8_per_token import FP8PerTokenQuantization +from .w8a8 import W8A8Quantization +from .awq import AWQQuantization + +__all__ = [ + "NoQuantization", + "FP8Block128Quantization", + "FP8PerTokenQuantization", + "W8A8Quantization", + "AWQQuantization", +] diff --git a/lightllm/common/quantization/awq_quant.py b/lightllm/common/quantization/types/awq.py similarity index 62% rename from lightllm/common/quantization/awq_quant.py rename to lightllm/common/quantization/types/awq.py index d523cce757..eedc5b67b5 100644 --- a/lightllm/common/quantization/awq_quant.py +++ b/lightllm/common/quantization/types/awq.py @@ -1,66 +1,78 @@ -import os import torch -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from typing import Any -from typing import TYPE_CHECKING, Optional, Tuple -from lightllm.utils.dist_utils import get_current_device_id - -from .quantize_method import WeightPack - -if HAS_VLLM: - awq_dequantize = vllm_ops.awq_dequantize - awq_gemm = vllm_ops.awq_gemm - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, - marlin_permute_scales, - awq_to_marlin_zero_points, - should_use_atomic_add_reduce, - marlin_make_empty_g_idx, - marlin_make_workspace_new, - ) - from vllm.scalar_type import scalar_types +from typing import Any, Optional, Tuple - TYPE_MAP = { - 4: scalar_types.uint4, - 8: scalar_types.uint8, - } +from lightllm.common.quantization.quantize_method import 
QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops + + if HAS_VLLM: + awq_dequantize = vllm_ops.awq_dequantize + awq_gemm = vllm_ops.awq_gemm + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + marlin_permute_scales, + awq_to_marlin_zero_points, + should_use_atomic_add_reduce, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + ) + from vllm.scalar_type import scalar_types + + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + else: + awq_dequantize = None + awq_gemm = None + TYPE_MAP = {} +except ImportError: + HAS_VLLM = False + awq_dequantize = None + awq_gemm = None + TYPE_MAP = {} + + +def is_awq_marlin_compatible(quantization_config: dict[str, Any]) -> bool: + if not HAS_VLLM: + return False + quant_method = quantization_config.get("quant_method", "").lower() + num_bits = quantization_config.get("bits") + group_size = quantization_config.get("group_size") + zero_point = quantization_config.get("zero_point") -class AWQBaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + if not torch.cuda.is_available(): + return False - self.cache_manager = g_cache_manager + if quant_method != "awq": + return False - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): - raise NotImplementedError("AWQ online quantization is not supported yet.") + if num_bits is None or group_size is None or zero_point is None: + return False - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - raise NotImplementedError("AWQ online quantization is not supported yet.") + if num_bits not in TYPE_MAP: + return False - @property - def method_name(self): - return "awq-base" + return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) -@QUANTMETHODS.register("awq") -class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): +@QUANTMETHODS.register(["awq", "awq_marlin"]) +class AWQQuantization(QuantizationMethod): def __init__(self): super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + if not HAS_VLLM: + raise RuntimeError("vLLM is required for AWQ quantization but is not installed.") + + self.cache_manager = g_cache_manager self.pack_factor = 8 self.weight_scale_suffix = "scales" self.weight_zero_point_suffix = "qzeros" @@ -68,11 +80,38 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = True + self._use_marlin = False + self._marlin_initialized = False + + def _init_marlin(self): + if self._marlin_initialized: + return + + self.nbits = 4 + self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) + self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) + self.workspace = marlin_make_workspace_new(torch.device("cuda")) + self.vllm_quant_type = TYPE_MAP[self.nbits] + self.tile_size = 16 + self._marlin_initialized = True 
+ + def _check_and_set_marlin(self): + if self.hf_quantization_config is None: + self._use_marlin = False + return + + self._use_marlin = is_awq_marlin_compatible(self.hf_quantization_config) + if self._use_marlin: + self._init_marlin() + logger.info("AWQQuantization using Marlin backend") + else: + logger.info("AWQQuantization using basic AWQ backend") + @property def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( @@ -83,6 +122,22 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if not hasattr(self, "_checked_marlin"): + self._check_and_set_marlin() + self._checked_marlin = True + + if self._use_marlin: + return self._apply_marlin(input_tensor, weight_pack, out, bias) + else: + return self._apply_basic(input_tensor, weight_pack, out, bias) + + def _apply_basic( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + bias: Optional[torch.Tensor], ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -99,81 +154,12 @@ def apply( out.add_(bias) return out - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - group_size = self.hf_quantization_config["group_size"] - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) - weight_zero_point = torch.empty( - expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) - - def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - start_idx = start_idx // self.pack_factor - weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) - return - - def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) - return - - def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - start_idx = start_idx // self.pack_factor - end_idx = start_idx + weight_zero_point.shape[1] - weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) - return - - -@QUANTMETHODS.register("awq_marlin") -class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): - def __init__(self): - super().__init__() - self.pack_factor = 8 - self.nbits = 4 - self.weight_scale_suffix = "scales" - self.weight_zero_point_suffix = "qzeros" - self.weight_suffix = "qweight" - self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) - self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) - self.workspace = marlin_make_workspace_new(torch.device("cuda")) - self.vllm_quant_type = TYPE_MAP[self.nbits] - self.has_weight_scale = True - self.has_weight_zero_point = True - self.tile_size = 16 - - @property - def method_name(self): - return 
"awq_marlin" - - def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: - raise NotImplementedError("AWQ online quantization is not supported yet.") - - def params_repack( - self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 一些量化方法在将参数完成量化后,为了加速性能,还需要将参数进行重拍,使算子性能达到最优,如awq方法。 - """ - weight = self._process_weight_after_loading(weight.cuda(get_current_device_id())) - weight_scale = self._process_weight_scale_after_loading( - weight_scale.cuda(get_current_device_id()).to(dtype_type) - ) - weight_zero_point = self._process_weight_zero_point_after_loading( - weight_zero_point.cuda(get_current_device_id()) - ) - return weight, weight_scale, weight_zero_point - - def apply( + def _apply_marlin( self, input_tensor: torch.Tensor, weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor], + bias: Optional[torch.Tensor], ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -214,6 +200,30 @@ def apply( def create_weight( self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + if not hasattr(self, "_checked_marlin"): + self._check_and_set_marlin() + self._checked_marlin = True + + if self._use_marlin: + return self._create_weight_marlin(out_dim, in_dim, dtype, device_id, num_experts) + else: + return self._create_weight_basic(out_dim, in_dim, dtype, device_id, num_experts) + + def _create_weight_basic( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def _create_weight_marlin( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 ) -> WeightPack: self.n = out_dim self.k = in_dim @@ -229,6 +239,20 @@ def create_weight( return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + if not hasattr(self, "_checked_marlin"): + self._check_and_set_marlin() + self._checked_marlin = True + + if self._use_marlin: + self._load_weight_marlin(weight, weight_pack, start_idx) + else: + self._load_weight_basic(weight, weight_pack, start_idx) + + def _load_weight_basic(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) + + def _load_weight_marlin(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: assert self.hf_quantization_config is not None, "hf_quantization_config is not set" device_id = get_current_device_id() repack_weight = 
vllm_ops.awq_marlin_repack( @@ -239,9 +263,21 @@ def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: ) start_idx = start_idx // self.pack_factor * self.tile_size weight_pack.weight[:, start_idx : start_idx + repack_weight.shape[1]].copy_(repack_weight) - return def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + if not hasattr(self, "_checked_marlin"): + self._check_and_set_marlin() + self._checked_marlin = True + + if self._use_marlin: + self._load_weight_scale_marlin(weight_scale, weight_pack, start_idx) + else: + self._load_weight_scale_basic(weight_scale, weight_pack, start_idx) + + def _load_weight_scale_basic(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) + + def _load_weight_scale_marlin(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: assert self.hf_quantization_config is not None, "hf_quantization_config is not set" group_size = self.hf_quantization_config["group_size"] device_id = get_current_device_id() @@ -252,9 +288,27 @@ def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, group_size=self.hf_quantization_config["group_size"], ) weight_pack.weight_scale[:, start_idx : start_idx + repack_weight_scale.shape[1]].copy_(repack_weight_scale) - return def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + if not hasattr(self, "_checked_marlin"): + self._check_and_set_marlin() + self._checked_marlin = True + + if self._use_marlin: + self._load_weight_zero_point_marlin(weight_zero_point, weight_pack, start_idx) + else: + self._load_weight_zero_point_basic(weight_zero_point, weight_pack, start_idx) + + def _load_weight_zero_point_basic( + self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int + ) -> None: + start_idx = start_idx // self.pack_factor + end_idx = start_idx + weight_zero_point.shape[1] + weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) + + def _load_weight_zero_point_marlin( + self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int + ) -> None: device_id = get_current_device_id() repack_weight_zero_point = awq_to_marlin_zero_points( weight_zero_point.cuda(device_id), @@ -266,29 +320,3 @@ def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: W weight_pack.weight_zero_point[:, start_idx : start_idx + repack_weight_zero_point.shape[1]].copy_( repack_weight_zero_point ) - return - - -# adapted from -# https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 -def is_awq_marlin_compatible(quantization_config: dict[str, Any]): - # Extract data from quant config. - quant_method = quantization_config.get("quant_method", "").lower() - num_bits = quantization_config.get("bits") - group_size = quantization_config.get("group_size") - zero_point = quantization_config.get("zero_point") - - if not torch.cuda.is_available(): - return False - - if quant_method != "awq": - return False - - # If we cannot find the info needed in the config, cannot convert. 
- if num_bits is None or group_size is None or zero_point is None: - return False - - if num_bits not in TYPE_MAP: - return False - - return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) diff --git a/lightllm/common/quantization/types/fp8_block128.py b/lightllm/common/quantization/types/fp8_block128.py new file mode 100644 index 0000000000..4144dddde2 --- /dev/null +++ b/lightllm/common/quantization/types/fp8_block128.py @@ -0,0 +1,216 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType +from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + import deep_gemm + + HAS_DEEPGEMM = True +except ImportError: + HAS_DEEPGEMM = False + +try: + from lightllm.utils.vllm_utils import HAS_VLLM + + if HAS_VLLM: + from lightllm.utils.vllm_utils import cutlass_scaled_mm + else: + cutlass_scaled_mm = None +except ImportError: + HAS_VLLM = False + cutlass_scaled_mm = None + + +def _deepgemm_fp8_nt(a_tuple, b_tuple, out): + if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): + return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) + if hasattr(deep_gemm, "fp8_gemm_nt"): + return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) + raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") + + +@QUANTMETHODS.register(["fp8-block128"]) +class FP8Block128Quantization(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + self.block_size = 128 + self.weight_scale_suffix = "weight_scale_inv" + self.has_weight_scale = True + self.has_weight_zero_point = False + + self._backend = QUANT_BACKEND.get_backend("fp8-block128") + logger.info(f"FP8Block128Quantization using backend: {self._backend.name}") + + @property + def method_name(self): + return "fp8-block128" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant + + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty + m, k = input_tensor.shape + + if self._backend == BackendType.DEEPGEMM: + return self._apply_deepgemm(input_tensor, weight_pack, out, alloc_func, bias) + elif self._backend == BackendType.VLLM: + return self._apply_vllm(input_tensor, weight_pack, out, alloc_func, bias) + else: + return 
self._apply_triton(input_tensor, weight_pack, out, alloc_func, bias) + + def _apply_deepgemm( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + alloc_func, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + m, k = input_tensor.shape + n = qweight.shape[0] + + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, + self.block_size, + dtype=qweight.dtype, + column_major_scales=True, + scale_tma_aligned=True, + alloc_func=alloc_func, + ) + + if out is None: + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) + + if bias is not None: + out.add_(bias) + return out + + def _apply_vllm( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + alloc_func, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale.t() + m, k = input_tensor.shape + n = qweight.shape[1] + + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func + ) + + if out is None: + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + if n % 128 != 0: + w8a8_block_fp8_matmul( + qinput_tensor, + qweight, + input_scale, + weight_scale, + out, + (self.block_size, self.block_size), + dtype=input_tensor.dtype, + ) + else: + input_scale = input_scale.t().contiguous().t() + cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) + return out + + if bias is not None: + out.add_(bias) + return out + + def _apply_triton( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + alloc_func, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + m, k = input_tensor.shape + n = qweight.shape[1] + + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func + ) + + if out is None: + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + w8a8_block_fp8_matmul( + qinput_tensor, + qweight, + input_scale, + weight_scale, + out, + (self.block_size, self.block_size), + dtype=input_tensor.dtype, + ) + + if bias is not None: + out.add_(bias) + return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, 
weight_pack: WeightPack, start_idx: int) -> None: + if weight_pack.weight_zero_point is not None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return diff --git a/lightllm/common/quantization/types/fp8_per_token.py b/lightllm/common/quantization/types/fp8_per_token.py new file mode 100644 index 0000000000..c49bc89ff3 --- /dev/null +++ b/lightllm/common/quantization/types/fp8_per_token.py @@ -0,0 +1,172 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType +from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + from lightllm.utils.vllm_utils import HAS_VLLM + + if HAS_VLLM: + from lightllm.utils.vllm_utils import vllm_ops, cutlass_scaled_mm + else: + vllm_ops = None + cutlass_scaled_mm = None +except ImportError: + HAS_VLLM = False + vllm_ops = None + cutlass_scaled_mm = None + +try: + from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops + + if HAS_LIGHTLLM_KERNEL: + + def scaled_fp8_quant(tensor, *args, **kwargs): + return light_ops.per_token_quant_bf16_fp8(tensor) + + else: + if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant + else: + scaled_fp8_quant = None +except ImportError: + HAS_LIGHTLLM_KERNEL = False + if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant + else: + scaled_fp8_quant = None + + +@QUANTMETHODS.register(["fp8-per-token", "fp8w8a8"]) +class FP8PerTokenQuantization(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + self.is_moe = False + self.has_weight_scale = True + self.has_weight_zero_point = False + self._backend = QUANT_BACKEND.get_backend("fp8-per-token") + logger.info(f"FP8PerTokenQuantization using backend: {self._backend.name}") + + @property + def method_name(self): + return "fp8-per-token" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + """Quantize weights using per-token FP8 quantization.""" + if self.is_moe: + return self._quantize_moe(weight, output, offset) + + if scaled_fp8_quant is None: + raise RuntimeError("No FP8 quantization kernel available. Install vLLM or lightllm-kernel.") + + qweight, weight_scale = scaled_fp8_quant( + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + ) + output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) + output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) + return + + def _quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int) -> None: + if scaled_fp8_quant is None: + raise RuntimeError("No FP8 quantization kernel available. 
Install vLLM or lightllm-kernel.") + + num_experts = weight.shape[0] + qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) + weight_scales = [] + for i in range(num_experts): + qweight, weight_scale = scaled_fp8_quant( + weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + ) + qweights[i] = qweight + weight_scales.append(weight_scale) + weight_scale = torch.stack(weight_scales, dim=0).contiguous() + output.weight.copy_(qweights) + output.weight_scale.copy_(weight_scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self._backend == BackendType.TRITON: + return self._apply_triton(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + else: + return self._apply_vllm(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + + def _apply_vllm( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) + return out + + def _apply_triton( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + if scaled_fp8_quant is None: + raise RuntimeError("No FP8 quantization kernel available. 
Install vLLM or lightllm-kernel.") + + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) + + if bias is not None: + out.add_(bias) + return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/types/no_quant.py similarity index 90% rename from lightllm/common/quantization/no_quant.py rename to lightllm/common/quantization/types/no_quant.py index 987601c5d6..e92d821c15 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/types/no_quant.py @@ -1,11 +1,14 @@ -from .quantize_method import QuantizationMethod, WeightPack -from .registry import QUANTMETHODS import torch from typing import Optional +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS + @QUANTMETHODS.register("none") class NoQuantization(QuantizationMethod): + """No quantization - uses full precision weights.""" + def apply( self, input_tensor: torch.Tensor, diff --git a/lightllm/common/quantization/types/w8a8.py b/lightllm/common/quantization/types/w8a8.py new file mode 100644 index 0000000000..e3b0ef592b --- /dev/null +++ b/lightllm/common/quantization/types/w8a8.py @@ -0,0 +1,108 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# Conditional imports for optional backends +try: + from lightllm.utils.vllm_utils import HAS_VLLM + + if HAS_VLLM: + from lightllm.utils.vllm_utils import vllm_ops, cutlass_scaled_mm + else: + vllm_ops = None + cutlass_scaled_mm = None +except ImportError: + HAS_VLLM = False + vllm_ops = None + cutlass_scaled_mm = None + + +@QUANTMETHODS.register(["w8a8", "vllm-w8a8"]) +class W8A8Quantization(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + self.has_weight_scale = True + self.has_weight_zero_point = False + + self._backend = QUANT_BACKEND.get_backend("w8a8") + + if self._backend == BackendType.TRITON: + if not HAS_VLLM: + raise NotImplementedError( + "W8A8 Triton fallback is not yet implemented. " + "Please install vLLM or disable LIGHTLLM_USE_TRITON_QUANT." 
+ ) + self._backend = BackendType.VLLM + logger.warning("W8A8 Triton fallback not implemented, falling back to vLLM backend") + + if self._backend == BackendType.VLLM and not HAS_VLLM: + raise RuntimeError("vLLM is required for W8A8 quantization but is not installed.") + + logger.info(f"W8A8Quantization using backend: {self._backend.name}") + + @property + def method_name(self): + return "w8a8" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + weight = weight.float().cuda(self.device_id_) + scale = weight.abs().max(dim=-1)[0] / 127 + weight = weight / scale.reshape(-1, 1) + weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) + output.weight[offset : offset + weight.shape[0]].copy_(weight) + output.weight_scale[offset : offset + weight.shape[0]].copy_(scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO: Currently only vLLM backend is implemented + return self._apply_vllm(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + + def _apply_vllm( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) + return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/common/quantization/w8a8_quant.py b/lightllm/common/quantization/w8a8_quant.py deleted file mode 100644 index 1728e799db..0000000000 --- a/lightllm/common/quantization/w8a8_quant.py +++ /dev/null @@ -1,253 +0,0 @@ -import os -import torch - -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token -from .quantize_method import QuantizationMethod -from .registry import QUANTMETHODS -import torch.nn.functional as F -from typing import Optional, TYPE_CHECKING -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul -from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops - - -from .quantize_method import WeightPack - -if HAS_LIGHTLLM_KERNEL: - - def scaled_fp8_quant(tensor, *args, **kwargs): - return light_ops.per_token_quant_bf16_fp8(tensor) - -else: - if 
HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant - -LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [ - "ON", - "TRUE", - "1", -] - - -class BaseQuantizationMethod(QuantizationMethod): - def __init__(self): - super().__init__() - assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - raise NotImplementedError("Not implemented") - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - raise NotImplementedError("Not implemented") - - @property - def method_name(self): - return "w8a8-base" - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - raise NotImplementedError("Not implemented") - - -@QUANTMETHODS.register(["vllm-w8a8", "w8a8"]) -class w8a8QuantizationMethod(BaseQuantizationMethod): - def __init__(self): - super().__init__() - self.has_weight_scale = True - self.has_weight_zero_point = False - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - weight = weight.float().cuda(self.device_id_) - scale = weight.abs().max(dim=-1)[0] / 127 - weight = weight / scale.reshape(-1, 1) - weight = torch.round(weight.clamp(min=-128, max=127)).to(dtype=torch.int8) - output.weight[offset : offset + weight.shape[0]].copy_(weight) - output.weight_scale[offset : offset + weight.shape[0]].copy_(scale) - return - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - input_scale = None - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - input_scale = None # dynamic quantization for input tensor - x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) - m = input_tensor.shape[0] - n = qweight.shape[1] - if out is None: - if use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) - return out - - @property - def method_name(self): - return "vllm-w8a8" - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) - - -@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"]) -class FP8w8a8QuantizationMethod(BaseQuantizationMethod): - def __init__(self): - super().__init__() - self.is_moe = False - self.has_weight_scale = True - self.has_weight_zero_point = False - - def quantize(self, weight: torch.Tensor, 
output: WeightPack, offset: int = 0) -> None: - if self.is_moe: - return self.quantize_moe(weight, output, offset) - qweight, weight_scale = scaled_fp8_quant( - weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) - output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) - return - - def quantize_moe(self, weight: torch.Tensor) -> WeightPack: - num_experts = weight.shape[0] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) - weight_scales = [] - for i in range(num_experts): - qweight, weight_scale = scaled_fp8_quant( - weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.stack(weight_scales, dim=0).contiguous() - return WeightPack(weight=qweights, weight_scale=weight_scale) - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - m = input_tensor.shape[0] - n = qweight.shape[1] - if out is None: - if use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - if LIGHTLLM_USE_TRITON_FP8_SCALED_MM: - out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) - else: - cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) - return out - - @property - def method_name(self): - return "vllm-fp8w8a8" - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) - - -@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"]) -class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): - def __init__(self): - super().__init__() - self.block_size = 128 - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - - device = output.weight.device - weight, scale = weight_quant(weight.cuda(device), self.block_size) - output.weight[offset : offset + weight.shape[0], :].copy_(weight) - output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) - return - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = 
weight_pack.weight_scale.t() - input_scale = None # dynamic quantization for input tensor - m, k = input_tensor.shape - n = qweight.shape[1] - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - if input_scale is None: - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func - ) - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - if n % 128 != 0: - w8a8_block_fp8_matmul( - qinput_tensor, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - else: - input_scale = input_scale.t().contiguous().t() - cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) - return out - - @property - def method_name(self): - return "vllm-fp8w8a8-b128" - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty( - expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) From f0481a886167e51d5966c9c6bcd311cd38926a3f Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:56:45 +0800 Subject: [PATCH 22/43] fix openai v1 (#1178) Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com> --- lightllm/models/qwen2/model.py | 2 +- lightllm/server/api_models.py | 5 +++-- lightllm/server/api_openai.py | 2 +- lightllm/server/core/objs/sampling_params.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index d2f067c42c..106610ff09 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -18,7 +18,7 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - if self.config["sliding_window"] is None: + if self.config.get("sliding_window", None) is None: self.config["sliding_window"] = self.max_total_token_num # rename key [SYM: to be confirmed] return diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 7b9cdd5012..f30ecc55fe 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -24,9 +24,10 @@ class Message(BaseModel): class Function(BaseModel): """Function descriptions.""" - description: Optional[str] = Field(default=None, examples=[None]) name: Optional[str] = None - parameters: Optional[object] = None + description: Optional[str] = Field(default=None, examples=[None]) + parameters: Optional[dict] = None + response: Optional[dict] = None class Tool(BaseModel): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 6a8c232dc5..d91bb1d947 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -81,7 +81,7 @@ def _process_tool_call_id( # SGLang sets call_item.tool_index to the *local* position inside that message. # Therefore, the index must be corrected by using # `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered. 
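A minimal sketch of that correction, assuming a hypothetical `ToolCall` container with `name` and `tool_index` fields (only the id format mirrors the line below):

```python
from dataclasses import dataclass


@dataclass
class ToolCall:
    name: str
    tool_index: int  # local position of the call inside one assistant message


def make_tool_call_id(call: ToolCall, history_tool_calls_cnt: int) -> str:
    # history_tool_calls_cnt counts tool calls from *earlier* messages, so adding
    # the message-local index yields a globally unique, ordered identifier.
    return f"functions.{call.name}:{history_tool_calls_cnt + call.tool_index}"


# Example: two earlier messages already produced 3 tool calls in total.
calls = [ToolCall("get_weather", 0), ToolCall("get_time", 1)]
print([make_tool_call_id(c, 3) for c in calls])
# -> ['functions.get_weather:3', 'functions.get_time:4']
```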
- tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}" + tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}" logger.debug( f"Process tool call idx, parser: {tool_call_parser}, \ tool_call_id: {tool_call_id}, \ diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f073319d79..d955aa6a87 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -7,7 +7,7 @@ _SAMPLING_EPS = 1e-5 DEFAULT_INPUT_PENALTY = os.getenv("INPUT_PENALTY", "False").upper() in ["ON", "TRUE", "1"] -SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "True").upper() in ["ON", "TRUE", "1"] +SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "False").upper() in ["ON", "TRUE", "1"] # 从环境变量获取最大长度限制 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256)) From 2fbd2b8a7b180570129bd539bb34cbe0a5dbb22a Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Thu, 15 Jan 2026 13:12:48 +0800 Subject: [PATCH 23/43] add diverse_stage2 add optimize diverse_stage1 (#1174) Co-authored-by: wangzaijun --- ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + 
...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + .../basemodel/attention/triton/int8kv.py | 2 +- ...se.py => int8kv_flash_decoding_diverse.py} | 38 +-- ...> int8kv_flash_decoding_diverse_stage1.py} | 32 +- .../int8kv_flash_decoding_diverse_stage2.py | 306 ++++++++++++++++++ ...> int8kv_flash_decoding_diverse_stage3.py} | 0 ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 22 ++ ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 22 ++ .../benchmark/static_inference/model_infer.py | 5 +- .../llama_gqa_diverse_decode_stage1_tuning.py | 2 +- .../llama_gqa_diverse_decode_stage2_tuning.py | 296 +++++++++++++++++ ... => test_int8kv_flash_decoding_diverse.py} | 27 +- ...t_int8kv_flash_decoding_diverse_stage1.py} | 108 ++++++- ...st_int8kv_flash_decoding_diverse_stage2.py | 293 +++++++++++++++++ ...t_int8kv_flash_decoding_diverse_stage3.py} | 2 +- ...pl_int8kv_flash_decoding_diverse_stage2.py | 132 -------- 77 files changed, 1160 insertions(+), 189 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse.py => int8kv_flash_decoding_diverse.py} (75%) rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse_stage1.py => int8kv_flash_decoding_diverse_stage1.py} (90%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse_stage3.py => int8kv_flash_decoding_diverse_stage3.py} (100%) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 test/kernel/llama_gqa_diverse_decode_stage2_tuning.py rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse.py => test_int8kv_flash_decoding_diverse.py} (85%) rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse_stage1.py => test_int8kv_flash_decoding_diverse_stage1.py} (53%) create mode 100644 unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse_stage3.py => test_int8kv_flash_decoding_diverse_stage3.py} (96%) delete mode 100644 unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py diff --git 
a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..67dd3852c5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..555386ebdd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..6f92439c1c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, 
"num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..67dd3852c5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..555386ebdd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..6f92439c1c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, 
"256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7f69e86a86 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7d8dc868c3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..3e543b2ea3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": 
{"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7f69e86a86 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7d8dc868c3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..3e543b2ea3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, 
"num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..a3b0edde6d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..22d1ce6f69 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..328fcec837 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, 
"num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..a3b0edde6d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..22d1ce6f69 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..328fcec837 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": 
{"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..4c4ae86241 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..61884a9375 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..037bfd2913 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, 
"num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..4c4ae86241 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..61884a9375 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..037bfd2913 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, 
"num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..e2028e2d2a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7e99dc1be2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d6b46dda8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, 
"128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..e2028e2d2a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7e99dc1be2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d6b46dda8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, 
"num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7795b47e72 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..ff4d6efd49 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..1d8ca6967b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, 
"num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7795b47e72 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..ff4d6efd49 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..1d8ca6967b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": 
{"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..2f1cd5dfd5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a369088bfb --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..cb4e6a0d3e --- /dev/null +++ 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..2f1cd5dfd5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a369088bfb --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 
0000000000..cb4e6a0d3e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..60827b791e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7b42cad466 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..9bb49d70b7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..60827b791e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7b42cad466 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git 
a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..9bb49d70b7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..bd3d1c418b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..f1b3539f5d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ 
No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..e12c05b966 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..bd3d1c418b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..f1b3539f5d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, 
"num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..e12c05b966 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..c83dca52d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..59a3e1051c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, 
"num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b7c4eaa9f --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..c83dca52d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..59a3e1051c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, 
"128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b7c4eaa9f --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..abd760af04 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"32": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 3}, "16": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "64": {"8": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 4}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 9}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a560ce9e1a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ 
+{"32": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "64": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 7}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "16": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 7}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index 6a795c4376..975d7b629c 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -158,7 +158,7 @@ def diverse_decode_att( alloc_func=torch.empty, ) -> torch.Tensor: - from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + from ...triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py similarity index 75% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py index 6efb030ce6..ad6a8b5b3a 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py @@ -2,8 +2,9 @@ import torch from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.infer_struct import InferStateInfo -from .ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 -from .ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 +from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size @@ -37,10 +38,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" ) current_stream = torch.cuda.current_stream() @@ -65,21 +66,20 @@ def 
token_decode_attention_flash_decoding( ) stream2.wait_stream(current_stream) with torch.cuda.stream(stream2): - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.b_shared_seq_len, - infer_state.max_kv_seq_len, + flash_decode_stage2( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + b_shared_seq_len=infer_state.b_shared_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, ) current_stream.wait_stream(stream1) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py similarity index 90% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 7403f6dd5c..4dfaffef68 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -9,7 +9,7 @@ class GQADiverseDecodeStage1KernelConfig(KernelConfigs): - kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v1" + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v2" @classmethod @lru_cache(maxsize=200) @@ -113,6 +113,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( BLOCK_N: tl.constexpr, BLOCK_BATCH: tl.constexpr, KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, ): cur_batch = tl.program_id(0) shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch) @@ -128,6 +129,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( cur_q_head_range = tl.where(cur_q_head_range < q_head_end_index, cur_q_head_range, cur_kv_head * gqa_group_size) offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) cur_batch_seq_len = tl.load(b_shared_seq_len + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_batch_start_index = seq_start_block * BLOCK_SEQ @@ -162,25 +164,37 @@ def _fwd_kernel_flash_decode_diverse_stage1( mask=n_mask, other=0, ).to(tl.int64) - off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] - off_k_scale = off_k // KV_QUANT_GROUP_SIZE + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) att_value = tl.dot(q, k.to(q.dtype)) att_value *= sm_scale att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = 
k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] v = tl.load( - V + off_k.T, + V + off_v, mask=n_mask[:, None], other=0, ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) v_scale = tl.load( - V_scale + off_k_scale.T, - mask=n_mask[:, None], + V_scale + off_k_scale, + mask=n_mask[None, :], other=0.0, ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) cur_max_logic = tl.max(att_value, axis=1) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -274,7 +288,10 @@ def flash_decode_stage1( BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) if BLOCK_HEAD * BLOCK_BATCH < 16: BLOCK_BATCH = 16 // BLOCK_HEAD - + assert k.stride() == v.stride() + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, @@ -314,6 +331,7 @@ def flash_decode_stage1( BLOCK_N=BLOCK_N, BLOCK_BATCH=BLOCK_BATCH, KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 0000000000..f5c0b9c395 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,306 @@ +import torch +import triton +import triton.language as tl +from typing import Optional + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict +from lightllm.common.triton_utils.autotuner import autotune, Autotuner + + +class GQADiverseDecodeStage2KernelConfig(KernelConfigs): + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage2:v1" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + ) -> dict: + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + batch_size_config: dict = finded_config[ + min( + finded_config.keys(), + key=lambda x: abs(int(x) - avg_seq_len_in_batch), + ) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + config_json: Dict[int, Dict[int, Dict]], + ): + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def _fwd_kernel_flash_decode_diverse_stage2( + Q, + stride_qbs, + stride_qh, + stride_qd, + K, + K_scale, + stride_kbs, + stride_kh, + stride_kd, + V, + V_scale, + stride_vbs, + stride_vh, + stride_vd, + sm_scale, + Req_to_tokens, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + 
B_req_idx, + B_Seqlen, + b_shared_seq_len, + Mid_O, # [batch, head, seq_block_num, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, gqa_group_size) + + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) + cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_shared_len + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + store_seq_block = seq_start_block + tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + + block_n_size = tl.cdiv( + tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), + BLOCK_N, + ) + + if block_n_size == 0: + return + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + q = tl.load(Q + off_q) + + sum_exp = tl.zeros([gqa_group_size], dtype=tl.float32) + max_logic = tl.zeros([gqa_group_size], dtype=tl.float32) - float("inf") + acc = tl.zeros([gqa_group_size, BLOCK_HEADDIM], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + # q (4, 128) k (128, BLOCK_N) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + 
off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + store_seq_block * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + store_seq_block + tl.store( + Mid_O + off_mid_o, + (acc / sum_exp[:, None]), + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + (max_logic + tl.log(sum_exp)), + ) + return + + +@torch.no_grad() +def flash_decode_stage2( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + Req_to_tokens: torch.Tensor, + B_req_idx: torch.Tensor, + B_Seqlen: torch.Tensor, + b_shared_seq_len: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, + run_config: Optional[dict] = None, +): + if not run_config: + run_config = GQADiverseDecodeStage2KernelConfig.try_to_get_best_config( + batch_size=int(q.shape[0]), + avg_seq_len_in_batch=max_len_in_batch, + gqa_group_size=int(q.shape[1] // k.shape[1]), + q_head_dim=int(q.shape[2]), + block_seq=block_seq, + out_dtype=q.dtype, + ) + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + BLOCK_SEQ = block_seq + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + assert triton.next_power_of_2(Lk) == Lk + KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] + assert KV_QUANT_GROUP_SIZE == 8 + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + + assert k.stride() == v.stride() + + _fwd_kernel_flash_decode_diverse_stage2[grid]( + Q=q, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + K=k, + K_scale=k_scale, + stride_kbs=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + V=v, + V_scale=v_scale, + stride_vbs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + sm_scale=sm_scale, + Req_to_tokens=Req_to_tokens, + stride_req_to_tokens_b=Req_to_tokens.stride(0), + stride_req_to_tokens_s=Req_to_tokens.stride(1), + B_req_idx=B_req_idx, + B_Seqlen=B_Seqlen, + b_shared_seq_len=b_shared_seq_len, + Mid_O=mid_out, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] + stride_mid_o_eb=mid_out_logsumexp.stride(0), + stride_mid_o_eh=mid_out_logsumexp.stride(1), + stride_mid_o_es=mid_out_logsumexp.stride(2), + gqa_group_size=gqa_group_size, + BLOCK_SEQ=block_seq, + BLOCK_HEADDIM=Lk, + BLOCK_N=BLOCK_N, + KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, + num_warps=num_warps, + num_stages=num_stages, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py rename to 
lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..9f44ee6c30 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..4fa2f949f2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 32, + "num_stages": 4, + "num_warps": 4 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 3, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index a8abd2ae64..7f1c2b493f 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -41,7 +41,10 @@ def test_model_inference(args): "run_mode": "normal", "max_seq_length": args.max_req_total_len, "disable_cudagraph": args.disable_cudagraph, - "mode": args.mode, + "llm_prefill_att_backend": args.llm_prefill_att_backend, + "llm_decode_att_backend": args.llm_decode_att_backend, + "llm_kv_type": args.llm_kv_type, + "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, diff --git a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py index d391d30650..f32b093448 100644 --- a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py +++ b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py @@ -4,7 +4,7 @@ import torch.multiprocessing as mp from typing import List from lightllm.utils.log_utils import init_logger -from 
lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( flash_decode_stage1, GQADiverseDecodeStage1KernelConfig, ) diff --git a/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py new file mode 100644 index 0000000000..13c8945e59 --- /dev/null +++ b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py @@ -0,0 +1,296 @@ +import torch +import os +import torch.multiprocessing as mp +from typing import List +from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, + GQADiverseDecodeStage2KernelConfig, +) +from lightllm.utils.watchdog_utils import Watchdog + +logger = init_logger(__name__) + + +def set_seed(): + import torch + import random + import numpy as np + + seed = 42 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return + + +@torch.no_grad() +def test_decode_attentions( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int = 20, + **run_config, +): + set_seed() + shared_seq_len = 0 + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + max_len_in_batch = 8192 + quant_group_size = 8 + + args = [] + for _ in range(test_count): + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=dtype, device="cuda") / 10 + kv_shape = (batch_size * seq_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, seq_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), + dtype=q.dtype, + device="cuda", + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), + dtype=q.dtype, + device="cuda", + ) + arg_list, kwargs = ( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + b_seq_len, + b_shared_seq_len, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, + ), dict(run_config=run_config) + args.append((arg_list, kwargs)) + + graph = torch.cuda.CUDAGraph() + arg_list, kwargs = args[0] + flash_decode_stage2(*arg_list, **kwargs) + with torch.cuda.graph(graph): + for index in range(test_count): + arg_list, kwargs = args[index] + flash_decode_stage2(*arg_list, **kwargs) + + graph.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + + cost_time = start_event.elapsed_time(end_event=end_event) + + logger.info(f"bf16 
{seq_len} cost time: {cost_time} ms") + return cost_time + + +def worker( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, + test_configs, + queue, +): + dog = Watchdog(timeout=10) + dog.start() + + try: + for index in range(len(test_configs)): + tuning_config = test_configs[index] + cost_time = test_decode_attentions( + block_seq=block_seq, + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + test_count=test_count, + **tuning_config, + ) + dog.heartbeat() + queue.put(cost_time) + except Exception as ex: + logger.error(str(ex) + f" config {tuning_config} batch_size {batch_size} seq_len {seq_len} dtype {dtype}") + import sys + import traceback + + traceback.print_exc() + sys.exit(-1) + pass + + +def get_test_configs(split_id, split_count): + index = 0 + for block_n in [16, 32, 64]: + for num_warps in [ + 2, + 4, + 8, + 16, + ]: + for num_stages in [ + 1, + 2, + 3, + 4, + 5, + 7, + 9, + 10, + 11, + ]: + t_config = { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + if index % split_count == split_id: + yield t_config + index += 1 + else: + index += 1 + + +def tuning_configs( + device_id: int, + device_count: int, + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, +): + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) + best_config, best_cost_time = None, 10000000 + queue = mp.Queue() + test_configs = [] + for t_config in get_test_configs(device_id, device_count): + test_configs.append(t_config) + if len(test_configs) < 64: + continue + + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + while len(test_configs) != 0: + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + logger.info(f"{best_config} best cost: {best_cost_time}") + return best_config, best_cost_time + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + from lightllm.utils.tuning_utils import mp_tuning + import collections + + block_seq = 256 + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + num_heads = 32 + kv_head_num = 8 + q_head_dim = 128 + gqa_group_size = num_heads // kv_head_num + + store_json_ans = collections.defaultdict(dict) + + for seq_len in seq_lens: + for batch_size in batch_sizes: + ans = mp_tuning( + tuning_configs, + { + "block_seq": block_seq, + "batch_size": batch_size, + "seq_len": seq_len, + "dtype": torch.bfloat16, + "test_count": 1, + }, + ) + store_json_ans[seq_len][batch_size] = ans + + 
GQADiverseDecodeStage2KernelConfig.save_config( + gqa_group_size=gqa_group_size, + q_head_dim=q_head_dim, + block_seq=block_seq, + out_dtype=str(torch.bfloat16), + config_json=store_json_ans, + ) diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py similarity index 85% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py index ac18ffb955..a01bbf32d8 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py @@ -1,7 +1,5 @@ import pytest -pytest.skip(reason="need install lightllmKernel", allow_module_level=True) - import torch from lightllm.utils.light_utils import light_ops @@ -42,31 +40,32 @@ def __init__( # @pytest.mark.parametrize("shared_seq_len", [512]) @pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550]) -def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len): +@pytest.mark.parametrize("batch_size", list(range(6, 121, 6))) +def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size): """ - 测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding + 测试 int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding 与 ppl_int8kv_flash_decoding (baseline) 的对比。 """ - from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding as baseline_attention, ) - batch_size = 6 num_heads = 32 kv_head_num = 8 mark_shared_group_size = 3 - seq_len = 1024 + seq_len = 3547 head_dim = 128 quant_group_size = 8 + max_len_in_batch = 8192 test_dtype = torch.bfloat16 # 创建测试数据 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim) + kv_scale_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") @@ -77,7 +76,9 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0 - req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + req_to_tokens = torch.arange(0, max_len_in_batch * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, max_len_in_batch + ) for i in range(batch_size): if i % mark_shared_group_size != 0: req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len] @@ -91,7 +92,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 baseline 的 
infer_state (不需要 b_shared_seq_len) baseline_infer_state = MockInferState( batch_size=batch_size, - max_kv_seq_len=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -100,7 +101,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 diverse 的 infer_state diverse_infer_state = MockInferState( batch_size=batch_size, - max_kv_seq_len=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -129,7 +130,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le alloc_tensor_func=alloc_tensor_func, ) - print(f"\nshared_seq_len={shared_seq_len}") + print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}") print(f"baseline_out: {baseline_out[0, 0, :4]}") print(f"diverse_out: {diverse_out[0, 0, :4]}") print(f"max diff: {(baseline_out - diverse_out).abs().max()}") diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py similarity index 53% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py index 5ef36e38e2..f3cb8de463 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py @@ -1,41 +1,48 @@ import pytest import torch -from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage1 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage1 import ( flash_decode_stage1, ) -@pytest.fixture -def setup_tensors(): - batch_size = 4 - num_heads = 4 - kv_head_num = 1 - seq_len = 256 +def create_tensors( + batch_size=4, + num_heads=4, + kv_head_num=1, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): head_dim = 128 - max_len_in_batch = seq_len block_seq = 256 - max_batch_group_size = 4 quant_group_size = 8 test_dtype = torch.bfloat16 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + Req_to_tokens = torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, 
req_to_tokens_len + ) B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") b_shared_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" ) mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" ) return { @@ -56,6 +63,11 @@ def setup_tensors(): } +@pytest.fixture +def setup_tensors(): + return create_tensors() + + def test_flash_decode_stage1_execution(setup_tensors): flash_decode_stage1( q=setup_tensors["q"], @@ -106,3 +118,71 @@ def test_flash_decode_stage1_execution(setup_tensors): assert torch.allclose( setup_tensors["mid_out_logsumexp"], true_mid_out_logsumexp, atol=1e-2 ), "LogSumExp output does not match expected values" + + +def autotune_and_benchmark(): + import triton + + batch_sizes = [8, 16, 32, 64] + seq_lens = [1024, 2048, 4096] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + batch_size=batch, + num_heads=32, + kv_head_num=8, + seq_len=seq, + max_len_in_batch=8192, + max_batch_group_size=8, + kv_len=seq, + req_to_tokens_len=seq, + ) + + def fn_triton(st=setup_tensors): + return flash_decode_stage1( + q=st["q"], + k=st["k"], + k_scale=st["k_scale"], + v=st["v"], + v_scale=st["v_scale"], + Req_to_tokens=st["Req_to_tokens"], + B_req_idx=st["B_req_idx"], + b_shared_seq_len=st["b_shared_seq_len"], + b_mark_shared_group=st["b_mark_shared_group"], + max_len_in_batch=st["max_len_in_batch"], + mid_out=st["mid_out"], + mid_out_logsumexp=st["mid_out_logsumexp"], + block_seq=st["block_seq"], + max_batch_group_size=st["max_batch_group_size"], + ) + + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12}") + print(f"{'-'*80}") + for r in results: + print(f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f}") + print(f"{'='*80}") + + +if __name__ == "__main__": + autotune_and_benchmark() diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 0000000000..c7d4442543 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,293 @@ +import pytest +import torch +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, +) + + +def create_tensors( + shared_seq_len, + batch_size=4, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + block_seq = 256 + quant_group_size 
= 8 + + test_dtype = torch.bfloat16 + + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) + + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + Req_to_tokens = torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, req_to_tokens_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" + ) + + return { + "q": q, + "k": k, + "k_scale": k_scale, + "v": v, + "v_scale": v_scale, + "Req_to_tokens": Req_to_tokens, + "B_req_idx": B_req_idx, + "b_seq_len": b_seq_len, + "b_shared_seq_len": b_shared_seq_len, + "b_mark_shared_group": b_mark_shared_group, + "max_len_in_batch": max_len_in_batch, + "mid_out": mid_out, + "mid_out_logsumexp": mid_out_logsumexp, + "block_seq": block_seq, + "max_batch_group_size": max_batch_group_size, + "head_dim": head_dim, + } + + +@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) +def test_flash_decode_stage2_execution(shared_seq_len): + setup_tensors = create_tensors(shared_seq_len) + + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ + "block_seq" + ] + mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] + mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] + + q = setup_tensors["q"] + k = setup_tensors["k"] + v = setup_tensors["v"] + true_mid_out = torch.zeros_like(mid_out) + true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) + new_q = q + new_k = k.to(q.dtype) + new_v = v.to(q.dtype) + + b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] + req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] + + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( + flash_decode_stage1 as gqa_flash_decode_stage1, 
+ ) + + gqa_flash_decode_stage1( + q=new_q, + k=new_k, + v=new_v, + Req_to_tokens=req_to_tokens, + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=b_seq_len, + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=true_mid_out, + mid_out_logsumexp=true_mid_out_logsumexp, + block_seq=setup_tensors["block_seq"], + ) + print(f"\nshared_seq_len={shared_seq_len}") + print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") + print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") + abs_diff = (mid_out - true_mid_out).abs() + max_diff = abs_diff.max() + max_diff_idx = abs_diff.argmax() + max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) + mid_out_value = mid_out[max_diff_idx_unraveled] + true_mid_out_value = true_mid_out[max_diff_idx_unraveled] + print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") + + assert torch.allclose( + mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 + ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" + assert torch.allclose( + mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 + ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" + + +if __name__ == "__main__": + import importlib + import triton + from lightllm.utils.light_utils import light_ops + + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + shared_seq_len=0, + batch_size=batch, + seq_len=seq, + max_len_in_batch=8192, + kv_len=seq, + req_to_tokens_len=seq, + ) + + # Outputs for CUDA implementation + mid_out_cuda = setup_tensors["mid_out"].clone() + mid_out_logsumexp_cuda = setup_tensors["mid_out_logsumexp"].clone() + + # Outputs for Triton implementation + mid_out_triton = setup_tensors["mid_out"].clone() + mid_out_logsumexp_triton = setup_tensors["mid_out_logsumexp"].clone() + + # Run CUDA to get reference + light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + mid_out_cuda, + mid_out_logsumexp_cuda, + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + + # Run Triton + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=mid_out_triton, + mid_out_logsumexp=mid_out_logsumexp_triton, + block_seq=setup_tensors["block_seq"], + ) + + # Compare results + diff_mid_out = torch.abs(mid_out_cuda - mid_out_triton) + diff_logsumexp = torch.abs(mid_out_logsumexp_cuda - mid_out_logsumexp_triton) + max_diff_out = diff_mid_out.max().item() + max_diff_logsumexp = diff_logsumexp.max().item() + mean_diff_out = diff_mid_out.mean().item() + mean_diff_logsumexp = diff_logsumexp.mean().item() + + cos_sim_out = torch.nn.functional.cosine_similarity( + mid_out_cuda.flatten(), mid_out_triton.flatten(), dim=0 + 
).item() + cos_sim_logsumexp = torch.nn.functional.cosine_similarity( + mid_out_logsumexp_cuda.flatten(), mid_out_logsumexp_triton.flatten(), dim=0 + ).item() + + print(f"\n[batch={batch}, seq={seq}] Consistency check:") + print(" mid_out:") + print(f" max_diff: {max_diff_out:.6f}, mean_diff: {mean_diff_out:.6f}, cosine_sim: {cos_sim_out:.8f}") + print(" logsumexp:") + print( + f" max_diff: {max_diff_logsumexp:.6f}, " + f"mean_diff: {mean_diff_logsumexp:.6f}, " + f"cosine_sim: {cos_sim_logsumexp:.8f}" + ) + + # Performance + fn_cuda = lambda: light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + setup_tensors["mid_out"], + setup_tensors["mid_out_logsumexp"], + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + ms_cuda = triton.testing.do_bench_cudagraph(fn_cuda, rep=100) + + fn_triton = lambda: flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + "cuda_ms": ms_cuda, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12} {'cuda_ms':<12} {'vs cuda':<10}") + print(f"{'-'*80}") + for r in results: + vs_cuda = f"{r['cuda_ms']/r['triton_ms']:.2f}x" + emoji = "🎉" if r["triton_ms"] < r["cuda_ms"] else "" + print( + f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f} {r['cuda_ms']:<12.3f}" + f"{vs_cuda:<10} {emoji}" + ) + print(f"{'='*80}") diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py similarity index 96% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py index 18550982b9..c1a0ca1e58 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py @@ -1,6 +1,6 @@ import pytest import torch -from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage3 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage3 import ( flash_diverse_decode_stage3, ) diff --git 
a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py deleted file mode 100644 index cde7734817..0000000000 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest - -pytest.skip(reason="need install lightllmkernel", allow_module_level=True) - -import torch -from lightllm.utils.light_utils import light_ops - - -def create_tensors(shared_seq_len): - batch_size = 4 - num_heads = 32 - kv_head_num = 8 - seq_len = 256 - head_dim = 128 - max_len_in_batch = seq_len - block_seq = 256 - max_batch_group_size = 4 - quant_group_size = 8 - - test_dtype = torch.bfloat16 - - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) - - q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") - k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) - B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") - b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") - b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") - b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") - mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" - ) - mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" - ) - - return { - "q": q, - "k": k, - "k_scale": k_scale, - "v": v, - "v_scale": v_scale, - "Req_to_tokens": Req_to_tokens, - "B_req_idx": B_req_idx, - "b_seq_len": b_seq_len, - "b_shared_seq_len": b_shared_seq_len, - "b_mark_shared_group": b_mark_shared_group, - "max_len_in_batch": max_len_in_batch, - "mid_out": mid_out, - "mid_out_logsumexp": mid_out_logsumexp, - "block_seq": block_seq, - "max_batch_group_size": max_batch_group_size, - "head_dim": head_dim, - } - - -@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) -def test_flash_decode_stage2_execution(shared_seq_len): - setup_tensors = create_tensors(shared_seq_len) - - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - setup_tensors["block_seq"], - setup_tensors["mid_out"], - setup_tensors["mid_out_logsumexp"], - 1.0 / (setup_tensors["head_dim"] ** 0.5), - setup_tensors["q"], - setup_tensors["k"], - setup_tensors["k_scale"], - setup_tensors["v"], - setup_tensors["v_scale"], - setup_tensors["Req_to_tokens"], - setup_tensors["B_req_idx"], - setup_tensors["b_seq_len"], - setup_tensors["b_shared_seq_len"], - setup_tensors["max_len_in_batch"], - ) - seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ - "block_seq" - ] - mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] - mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] - - q = 
setup_tensors["q"] - k = setup_tensors["k"] - v = setup_tensors["v"] - true_mid_out = torch.zeros_like(mid_out) - true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) - new_q = q - new_k = k.to(q.dtype) - new_v = v.to(q.dtype) - - b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] - req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] - - from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( - flash_decode_stage1 as gqa_flash_decode_stage1, - ) - - gqa_flash_decode_stage1( - q=new_q, - k=new_k, - v=new_v, - Req_to_tokens=req_to_tokens, - B_req_idx=setup_tensors["B_req_idx"], - B_Seqlen=b_seq_len, - max_len_in_batch=setup_tensors["max_len_in_batch"], - mid_out=true_mid_out, - mid_out_logsumexp=true_mid_out_logsumexp, - block_seq=setup_tensors["block_seq"], - ) - print(f"\nshared_seq_len={shared_seq_len}") - print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") - print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") - abs_diff = (mid_out - true_mid_out).abs() - max_diff = abs_diff.max() - max_diff_idx = abs_diff.argmax() - max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) - mid_out_value = mid_out[max_diff_idx_unraveled] - true_mid_out_value = true_mid_out[max_diff_idx_unraveled] - print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") - - assert torch.allclose( - mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 - ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" - assert torch.allclose( - mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 - ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" From 3e2e030a37f0d441c5150b32365717a1912d5b09 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 15 Jan 2026 05:49:32 +0000 Subject: [PATCH 24/43] refactor quantization (draft) --- .../layer_weights/meta_weights/norm_weight.py | 20 +- .../layer_weights/meta_weights/platform_op.py | 31 +- .../triton_kernel/dequantize_gemm_int4.py | 649 ------------------ .../triton_kernel/dequantize_gemm_int8.py | 209 ------ .../triton_kernel/{ => norm}/layernorm.py | 0 .../triton_kernel/{ => norm}/qk_norm.py | 0 .../triton_kernel/{ => norm}/rmsnorm.py | 0 .../triton_kernel/quantization}/__init__.py | 0 .../{ => quantization}/bmm_scaled_fp8.py | 0 .../quantization}/fp8act_quant_kernel.py | 0 .../fp8w8a8_block_gemm_kernel.py | 0 .../fp8w8a8_block_quant_kernel.py | 0 .../fp8w8a8_scaled_mm_per_token_kernel.py | 0 .../q_per_head_fp8_quant.py | 0 .../triton_kernel/quantize_gemm_int8.py | 376 ---------- lightllm/common/quantization/__init__.py | 105 +-- .../common/quantization/{types => }/awq.py | 0 lightllm/common/quantization/backend.py | 82 --- .../quantization/{types => }/fp8_block128.py | 0 .../quantization/{types => }/fp8_per_token.py | 0 .../quantization/{types => }/no_quant.py | 0 .../quantization/triton_quant/fp8/__init__.py | 0 .../common/quantization/types/__init__.py | 13 - .../common/quantization/{types => }/w8a8.py | 41 +- .../layer_weights/transformer_layer_weight.py | 1 - lightllm/server/api_cli.py | 12 +- 26 files changed, 76 insertions(+), 1463 deletions(-) delete mode 100644 lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py delete mode 100644 lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py rename lightllm/common/basemodel/triton_kernel/{ => norm}/layernorm.py (100%) 
rename lightllm/common/basemodel/triton_kernel/{ => norm}/qk_norm.py (100%) rename lightllm/common/basemodel/triton_kernel/{ => norm}/rmsnorm.py (100%) rename lightllm/common/{quantization/triton_quant => basemodel/triton_kernel/quantization}/__init__.py (100%) rename lightllm/common/basemodel/triton_kernel/{ => quantization}/bmm_scaled_fp8.py (100%) rename lightllm/common/{quantization/triton_quant/fp8 => basemodel/triton_kernel/quantization}/fp8act_quant_kernel.py (100%) rename lightllm/common/{quantization/triton_quant/fp8 => basemodel/triton_kernel/quantization}/fp8w8a8_block_gemm_kernel.py (100%) rename lightllm/common/{quantization/triton_quant/fp8 => basemodel/triton_kernel/quantization}/fp8w8a8_block_quant_kernel.py (100%) rename lightllm/common/{quantization/triton_quant/fp8 => basemodel/triton_kernel/quantization}/fp8w8a8_scaled_mm_per_token_kernel.py (100%) rename lightllm/common/basemodel/triton_kernel/{ => quantization}/q_per_head_fp8_quant.py (100%) delete mode 100644 lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py rename lightllm/common/quantization/{types => }/awq.py (100%) delete mode 100644 lightllm/common/quantization/backend.py rename lightllm/common/quantization/{types => }/fp8_block128.py (100%) rename lightllm/common/quantization/{types => }/fp8_per_token.py (100%) rename lightllm/common/quantization/{types => }/no_quant.py (100%) delete mode 100644 lightllm/common/quantization/triton_quant/fp8/__init__.py delete mode 100644 lightllm/common/quantization/types/__init__.py rename lightllm/common/quantization/{types => }/w8a8.py (75%) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 73b937b776..df12ec9b1b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -34,12 +34,12 @@ def _native_forward( variance = x_var.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + eps) x = (x * self.weight).to(self.data_type_) - if out is None: + if out is not None: out.copy_(x) return out return x - def _cuda_forward( + def _triton_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: assert input.ndim == 2 and self.weight.ndim == 1 @@ -47,6 +47,12 @@ def _cuda_forward( out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) + def _cuda_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # only triton implementation is supported for rmsnorm on cuda platform + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: @@ -80,12 +86,12 @@ def _native_forward( x = torch.nn.functional.layer_norm( input, normalized_shape=[self.dim], weight=self.weight, bias=self.bias, eps=eps ) - if out is None: + if out is not None: out.copy_(x.to(self.data_type_)) return out return x.to(self.data_type_) - def _cuda_forward( + def _triton_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: assert input.ndim == 2 and self.weight.ndim == 1 @@ -93,6 +99,12 @@ def _cuda_forward( out = alloc_func(input.shape, 
dtype=input.dtype, device=input.device) return layernorm_forward(x=input, weight=self.weight, bias=self.bias, eps=eps, out=out) + def _cuda_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # only triton implementation is supported for layernorm on cuda platform + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py index 127a543b25..1ba1610fc9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/platform_op.py @@ -3,6 +3,9 @@ from typing import Optional, Callable, Any from lightllm.utils.device_utils import get_platform, Platform from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class PlatformAwareOp(ABC): @@ -14,10 +17,12 @@ class PlatformAwareOp(ABC): def __init__(self): args = get_env_start_args() self.platform = get_platform(args.hardware_platform) - self.enable_torch_naive = args.enable_torch_naive + self.enable_torch_fallback = args.enable_torch_fallback + self.enable_triton_fallback = args.enable_triton_fallback self._forward = self._route_forward() def _route_forward(self) -> Callable: + method_name_map = { Platform.CUDA: "_cuda_forward", Platform.ASCEND: "_ascend_forward", @@ -33,14 +38,23 @@ def _route_forward(self) -> Callable: if callable(method): return method - if self.enable_torch_naive: + if self.enable_triton_fallback: + if hasattr(self, "_triton_forward"): + return self._triton_forward + logger.warning( + f"No triton implementation found for {self.__class__.__name__} on {self.platform.name} platform. " + f"Please implement {self.__class__.__name__}._triton_forward method, " + f"or set --enable_torch_fallback to use default implementation." + ) + + if self.enable_torch_fallback: return self._native_forward - # 如果都没有,抛出异常 + # if no implementation found, raise error raise NotImplementedError( - f"No implementation found for platform {self.platform.name}. " - f"Please implement _{self.platform.name}_forward method, " - f"or set --enable_torch_naive to use default implementation." + f"No implementation found for {self.__class__.__name__} on {self.platform.name} platform. " + f"Please implement {self.__class__.__name__}._{self.platform.name.lower()}_forward method, " + f"or set --enable_torch_fallback to use default implementation."
) @abstractmethod @@ -50,3 +64,8 @@ def _native_forward(self, *args, **kwargs) -> Any: @abstractmethod def _cuda_forward(self, *args, **kwargs) -> Any: raise NotImplementedError("cuda forward must implement this method") + + # Since Triton may be compatible with all hardware platforms in the future, + # so provide triton implementation as a fallback for all hardware platforms + def _triton_forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("triton forward must implement this method") diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py deleted file mode 100644 index 143d93b230..0000000000 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py +++ /dev/null @@ -1,649 +0,0 @@ -import time - -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - ], - key=['M', 'N', 'K', 'NO_GROUPS'], -) -@triton.jit -def matmul4_kernel( - a_ptr, b_ptr, c_ptr, - scales_ptr, zeros_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_scales_g, stride_scales_n, - stride_zeros_g, stride_zeros_n, - groupsize, NO_GROUPS: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - Compute the matrix multiplication C = A x B. 
- A is of shape (M, K) float16 - B is of shape (K//8, N) int32 - C is of shape (M, N) float16 - scales is of shape (G, N) float16 - zeros is of shape (G, N//8) int32 - groupsize is an int specifying the size of groups for scales and zeros. - G is K // groupsize. - Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient. - WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K. - WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N. - WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K. - """ - bits = 4 - infearure_per_bits = 8 - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - a_mask = (offs_am[:, None] < M) - # b_ptrs is set up such that it repeats elements along the K axis 8 times - b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) - scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,) - # zeros_ptrs is set up such that it repeats elements along the N axis 8 times - zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,) - # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros - shifter = (offs_k % infearure_per_bits) * bits - zeros_shifter = (offs_bn % infearure_per_bits) * bits - # If G == 1, scales and zeros are the same for all K, so we can load them once - if NO_GROUPS: - # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop - scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,) - zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 - # Unpack zeros - zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 - # zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16 - zeros = zeros * scales - # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) - # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension - # So this loop is along the infeatures dimension (K) - # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, num_pid_k): - a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_K) - b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated - if not NO_GROUPS: - g_id = k // (groupsize // BLOCK_SIZE_K) - ptr = scales_ptrs + g_id * stride_scales_g - scales = tl.load(ptr) # (BLOCK_SIZE_N,) - ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,) - zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 - # Unpack zeros - zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32 - zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16 - # Now we need to unpack b (which is 4-bit values) into 32-bit values - b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values - b = b * scales[None, :] - zeros[None, :] # Scale and shift - # print("data type", a, b) - accumulator += tl.dot(a, b.to(a.dtype)) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk - c = accumulator.to(c_ptr.dtype.element_ty) - # Store the result - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor: - """ - Compute the matrix multiplication C = A x B + bias. - Where B is quantized using GPTQ and groupsize = -1 into 4-bit values. - - A is of shape (..., K) float16 - qweight is of shape (K//8, N) int32 - scales is of shape (G, N) float16 - qzeros is of shape (G, N//8) int32 - bias is of shape (1, N) float16 - - groupsize is the number of infeatures in each group. 
- G = K // groupsize - - Returns C of shape (..., N) float16 - """ - assert x.shape[-1] == (qweight.shape[0] * 8), "A must be a multiple of 8 in the last dimension" - assert x.is_contiguous(), "A must be contiguous" - - M, K = x.shape - N = qweight.shape[1] - # This is based on the possible BLOCK_SIZE_Ks - # assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, "K must be a multiple of 16, 32, 64, and 128" - # # This is based on the possible BLOCK_SIZE_Ns - # assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, "N must be a multiple of 16, 32, 64, 128, and 256" - # # This is based on the possible BLOCK_SIZE_Ks - # assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, "groupsize must be a multiple of 32, 64, and 128" - - # output = torch.empty((M, N), device='cuda', dtype=torch.float16) - if output is None: - inplace = False - output = torch.empty((M, N), device=x.device, dtype=x.dtype) - else: - inplace = True - - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - matmul4_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), - group_size, group_size == K, - ) - # return output - if not inplace: - return output - - -@triton.autotune( - configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, 
num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - - ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] -) -@triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - bs_ptr, bzp_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, - group_size, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr - ): - """ - assert K % (BLOCK_SIZE_K * SPLIT_K) == 0 - """ - pid = tl.program_id(axis=0) - pid_sp_k = tl.program_id(axis=1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - - # [BLOCK_M, BLOCK_K] - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - # [BLOCK_K, BLOCK_N] but repeated 8 times in N - b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn - # tl.static_print("shape", a_ptrs, b_ptrs, bs_ptrs, bzp_ptrs) - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B. - # [BLOCK_K, BLOCK_N] but repeated group_size times in K - bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \ - + offs_bn[None, :] * stride_bsn - # [BLOCK_K, BLOCK_N] but repeated in K and N - bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \ - + (offs_bn[None, :] // 8) * stride_bzpn - b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0 - bzp_shift_bits = (offs_bn[None, :] % 8) * 4 - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - bs = tl.load(bs_ptrs) - bzp = tl.load(bzp_ptrs) - # We accumulate along the K dimension. - int_b = (b >> b_shift_bits) & 0xF - int_bzp = (bzp >> bzp_shift_bits) & 0xF - b = ((int_b - int_bzp) * bs).to(a.dtype) - accumulator += tl.dot(a, b.to(a.dtype)) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0 - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - c = accumulator.to(c_ptr.dtype.element_ty) - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. 
- offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -def matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor: - """ - """ - assert x.is_contiguous(), "A must be contiguous" - assert qweight.is_contiguous(), "B must be contiguous" - M, K = x.shape - N = scales.shape[1] - if output is None: - output = torch.zeros((M, N), device=x.device, dtype=x.dtype) - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], - ) - matmul_kernel[grid]( - x, qweight, output, - scales, qzeros, - M, N, K, - x.stride(0), x.stride(1), - qweight.stride(0), qweight.stride(1), - output.stride(0), output.stride(1), - scales.stride(0), scales.stride(1), - qzeros.stride(0), qzeros.stride(1), - group_size, - ) - return output - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - ], - key=['K', 'N'], -) -@triton.jit -def dequantize_kernel( - # Pointers to matrices - b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr, - # Matrix dimensions - K, N, group_size, - stride_bk, stride_bn, - stride_bsk, stride_bsn, - stride_bzpk, stride_bzpn, - stride_fpbk, stride_fpbn, - # Meta-parameters - BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, -): - """Dequantize tile [BLOCK_SIZE_K, BLOCK_SIZE_N] in full precision. - We should assert BLOCK_SIZE_N % 8 == 0. 
- weight[K // 8, N], scale[K // group_size, N], zp[K // group_size, N // group_size] - """ - k_block_idx = tl.program_id(axis=0) - n_block_idx = tl.program_id(axis=1) - offs_k = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - offs_n = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - fpb_offs = offs_k[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn - b_offs = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn - bzp_offs = (offs_k[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn - bs_offs = (offs_k[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn - n_mask = offs_n[None, :] < N - k_mask = offs_k[:, None] < K - mask = n_mask & k_mask - int32_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0) - zp_b = tl.load(b_zp_ptr + bzp_offs, mask=mask, other=0.0) - scale_b = tl.load(b_scale_ptr + bs_offs, mask=mask, other=0.0) - b_shift = (offs_k[:, None] % 8) * 4 - bzp_shift = (offs_n[None, :] % 8) * 4 - fp_weight = (((int32_b >> b_shift) & 0xF) - ((zp_b >> bzp_shift) & 0xF)) * scale_b - tl.store(fpb_ptr + fpb_offs, fp_weight, mask=mask) - - -def dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size): - Kw, N = b.shape - K = Kw * 8 - fp_b = torch.ones((K, N), device=device, dtype=dtype) - grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), - triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - dequantize_kernel[grid]( - b, b_scale, b_zero_point, fp_b, - K, N, group_size, - b.stride(0), b.stride(1), - b_scale.stride(0), b_scale.stride(1), - b_zero_point.stride(0), b_zero_point.stride(1), - fp_b.stride(0), fp_b.stride(1) - ) - return fp_b - - -def matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, group_size=128, out=None): - """ - Matmul dequantize int4 s1 dequantize weight to `fp_b` and do fp16 torch.mm, - this is for `prefill` stage, since weight size is fixed so is dequantize overhead, - perfill stage have more tokens to amortize dequant cost. - """ - assert a.is_contiguous(), "Matrix A must be contiguous" - # assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - Kw, N = b.shape - if out is None: - # Allocates output. - out = torch.empty((M, N), device=a.device, dtype=a.dtype) - fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size) - torch.mm(a, fp_b, out=out) - fp_b = None - return out - - -def quantize_int4(weight, group_size=128, tp_rank=0): - # Weight shape: [H1 // 8, H2] - # Scale shape: [H1 // group_size, H2] - # zero_pint shape: [H1 // group_size, H2 // 8] - - weight = weight.transpose(1, 0) - h1, h2 = weight.shape - assert h1 % 8 == 0 and h2 % 8 == 0, "H1 {} H2 {}".format(h1, h2) - assert h2 % group_size == 0, "H1 {} H2 {}".format(h1, h2) - weight = weight.contiguous().view(-1, group_size).cuda(tp_rank) - weight_max = weight.amax(-1, keepdim=True) - weight_max = torch.where(weight_max < 0, 0, weight_max) - weight_min = weight.amin(-1, keepdim=True) - weight_min = torch.where(weight_min > 0, 0, weight_min) - weight_range = weight_max - weight_min - scale = weight_range / (2 ** 4 - 1) - zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32) - weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2) - int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device) - int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device) - zero_point = zero_point.view(h1, -1) - scale = scale.view(h1, -1) - # pack 8 int4 in an int32 number. - # Weight pack in row. 
- for pack in range(0, h2, 8): - for i in range(8): - int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4) - # zero point pack in col. - for pack in range(0, h1, 8): - for i in range(8): - int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4) - ''' - fp_weight = torch.zeros(h1, h2).half().to(weight.device) - for pack in range(0, h1 // 8): - for i in range(8): - fp_weight[pack * 8 + i, :] = \ - ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16 - print((fp_weight - weight).abs().sum()) - - fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device) - for pack in range(0, h1 // 8): - for i in range(8): - fp_zp[pack * 8 + i, :] = \ - (int_zero_point[pack, :] >> (i * 4)) & 15 - - print((fp_zp - zero_point).abs().sum()) - ''' - weight = None - return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size - - -def unpack_int4(weight, scale, zp): - """ - Test function to verify quantize int4 is correct. - Will not be used in model inference. - """ - weight = weight.transpose(1, 0) - scale = scale.transpose(1, 0) - zp = zp.transpose(1, 0) - h1, h2 = weight.shape - group_size = h2 * 8 // scale.shape[1] - group_num = scale.shape[1] - fp_weight = torch.zeros(h1, h2 * 8).half().to(weight.device) - fp_zero_point = torch.zeros(h1, group_num).to(weight.device) - for pack in range(0, h2): - for i in range(8): - fp_weight[:, pack * 8 + i] = (weight[:, pack] >> (i * 4)) & 0xF - for pack in range(0, h1 // 8): - for i in range(8): - fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF - for g in range(group_num): - fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \ - fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1) - return fp_weight.transpose(1, 0) - - -def test_int4(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b) - for _ in range(10): - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point) - torch.cuda.synchronize() - t2 = time.time() - triton_time = t2 - t1 - print("Triton time cost", (t2 - t1)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - print("Torch time cost", (t2 - t1)) - return triton_time, torch_time - - -def test_correct_int4_s1(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = dequantize_int4(int_b, b_scale, b_zero_point, a.device, a.dtype, group_size) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_s1(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", 
cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_correct_int4_s2(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = unpack_int4(int_b, b_scale, b_zero_point) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_s2(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_correct_int4_gptq(M=32, K=4096, N=4096): - group_size = 128 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale, b_zero_point, _ = quantize_int4(b, group_size=group_size) - cos = torch.nn.CosineSimilarity(0) - fp_weight = unpack_int4(int_b, b_scale, b_zero_point) - print("Quantize cos", cos(fp_weight.flatten().to(torch.float32), b.flatten().to(torch.float32))) - triton_output = matmul_dequantize_int4_gptq(a, int_b, b_scale, b_zero_point, group_size) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[4, 8, 16, 32, 64, 128] + [ - 128 * i for i in range(2, 33, 2) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton-s1', 'dequantize', 'triton-s2', 'triton-gptq'], - # Label name for the lines - line_names=["cuBLAS", "Triton-s1", "Dequant(GB/s)", "Triton-s2", "Triton-gptq"], - # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-'), ('yellow', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
- args={}, - ) -) -def benchmark(M, provider): - K = 4096 - N = 4096 - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s1': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s1(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-s2': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_s2(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'dequantize': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: dequantize_int4(intb, b_scale, bzp, 'cuda', torch.float16, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) - if provider == 'triton-gptq': - intb, b_scale, bzp, _ = quantize_int4(b, group_size=64) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int4_gptq(a, intb, b_scale, bzp, 64), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - return perf(ms), perf(max_ms), perf(min_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - t1, t2 = test_int4(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int4(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - print("Triton time {} Torch time {}".format(st1, st2)) - - -if __name__ == "__main__": - # test_correct_int4_s1() - # test_correct_int4_s2() - # test_correct_int4_gptq() - benchmark.run(show_plots=True, print_data=True) - exit() - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) diff --git a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py deleted file mode 100644 index e2c5c0dc95..0000000000 --- a/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int8.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, 
num_warps=2), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), - ], - key=['K', 'N'], -) - - -@triton.jit -def dequantize_kernel( - # Pointers to matrices - b_ptr, b_scale_ptr, fpb_ptr, - # Matrix dimensions - K, N, - stride_bk, stride_bn, - stride_fpbk, stride_fpbn, - # Meta-parameters - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - k_block_idx = tl.program_id(axis=0) - n_block_idx = tl.program_id(axis=1) - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_n = tl.arange(0, BLOCK_SIZE_N) - b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn - fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \ - (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn - bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] - n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N - mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask - int_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0) - scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask, other=0.0) - tl.store(fpb_ptr + fpb_offs, int_b * scale_b, mask=mask) - - -def matmul_dequantize_int8(a, b, b_scale, out=None): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - # assert b.is_contiguous(), "Matrix B must be contiguous" - M, K = a.shape - K, N = b.shape - if out == None: - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - else: - c = out - fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype) - grid = lambda META: ( - triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - dequantize_kernel[grid]( - b, b_scale, fp_b, - K, N, - b.stride(0), b.stride(1), - fp_b.stride(0), fp_b.stride(1) - ) - torch.mm(a, fp_b, out=c) - return c - - -def quantize_int8(weight, axis=0, tp_rank=0): - # Weight shape: [H1, H2] - # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. 
- weight = (weight / scale).to(torch.int8) - if axis == 0: - weight = weight.t().contiguous().t() - scale = scale.squeeze(axis) - return weight.contiguous().cuda(tp_rank), scale.contiguous().cuda(tp_rank) - - -def test_int8(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale = quantize_int8(b) - for _ in range(10): - triton_output = matmul_dequantize_int8(a, int_b, b_scale.unsqueeze(0)) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - triton_output = matmul_dequantize_int8(a, int_b, b_scale.unsqueeze(0)) - torch.cuda.synchronize() - t2 = time.time() - triton_time = t2 - t1 - print("Triton time cost", (t2 - t1)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - print("Torch time cost", (t2 - t1)) - return triton_time, torch_time - - -def test_correct_int8(M=512, K=4096, N=4096): - import time - - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_b, b_scale = quantize_int8(b) - cos = torch.nn.CosineSimilarity(0) - triton_output = matmul_dequantize_int8(a, int_b, b_scale) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - print("Output cos ", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i for i in range(1, 33) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton'], - # Label name for the lines - line_names=["cuBLAS", "Triton"], - # Line styles - styles=[('green', '-'), ('blue', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. 
- args={}, - ) -) - - -def benchmark(M, N, K, provider): - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - if provider == 'triton': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - intb, b_scale = quantize_int8(b) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_dequantize_int8(a, intb, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - return perf(ms), perf(min_ms), perf(max_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - t1, t2 = test_int8(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - t1, t2 = test_int8(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - print("Triton time {} Torch time {}".format(st1, st2)) - - -if __name__ == "__main__": - test_correct_int8() - benchmark.run(show_plots=True, print_data=True) - - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) \ No newline at end of file diff --git a/lightllm/common/basemodel/triton_kernel/layernorm.py b/lightllm/common/basemodel/triton_kernel/norm/layernorm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/layernorm.py rename to lightllm/common/basemodel/triton_kernel/norm/layernorm.py diff --git a/lightllm/common/basemodel/triton_kernel/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/qk_norm.py rename to lightllm/common/basemodel/triton_kernel/norm/qk_norm.py diff --git a/lightllm/common/basemodel/triton_kernel/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/rmsnorm.py rename to lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py diff --git a/lightllm/common/quantization/triton_quant/__init__.py b/lightllm/common/basemodel/triton_kernel/quantization/__init__.py similarity index 100% rename from lightllm/common/quantization/triton_quant/__init__.py rename to lightllm/common/basemodel/triton_kernel/quantization/__init__.py diff --git a/lightllm/common/basemodel/triton_kernel/bmm_scaled_fp8.py b/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/bmm_scaled_fp8.py rename to lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8act_quant_kernel.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py similarity index 100% rename from 
lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_scaled_mm_per_token_kernel.py similarity index 100% rename from lightllm/common/quantization/triton_quant/fp8/fp8w8a8_scaled_mm_per_token_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_scaled_mm_per_token_kernel.py diff --git a/lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py b/lightllm/common/basemodel/triton_kernel/quantization/q_per_head_fp8_quant.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py rename to lightllm/common/basemodel/triton_kernel/quantization/q_per_head_fp8_quant.py diff --git a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py b/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py deleted file mode 100644 index 4f3f6a385c..0000000000 --- a/lightllm/common/basemodel/triton_kernel/quantize_gemm_int8.py +++ /dev/null @@ -1,376 +0,0 @@ -import time -import torch - -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({}, num_stages=2, num_warps=8), - triton.Config({}, num_stages=2, num_warps=4), - triton.Config({}, num_stages=2, num_warps=2), - triton.Config({}, num_stages=2, num_warps=1), - ], - key=['K'], -) -@triton.jit -def quantize_int8_perrow_kernel( - fpa_ptr, a_ptr, as_ptr, - M, K, - stride_fpam, stride_fpak, - stride_am, stride_ak, - stride_asm, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - - fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak - a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak - a_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - a_max = tl.maximum(a_max, tl.max(tl.abs(fpa), axis=1)) - fpa_ptrs += BLOCK_SIZE_K * stride_fpak - a_scale = (a_max / 127.) 
- fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - inta = (fpa / a_scale[:, None]).to(tl.int8) - tl.store(a_ptrs, inta, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K) - fpa_ptrs += BLOCK_SIZE_K * stride_fpak - a_ptrs += BLOCK_SIZE_K * stride_ak - as_offs = pid_m * BLOCK_SIZE_M * stride_asm + tl.arange(0, BLOCK_SIZE_M) - tl.store(as_ptr + as_offs, a_scale) - - -def quantize_int8_perrow(fpa): - a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8) - a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=fpa.dtype) - M, K = fpa.shape - BLOCK_SIZE_M = 1 - BLOCK_SIZE_K = triton.next_power_of_2(K) - grid = (M // BLOCK_SIZE_M,) - quantize_int8_perrow_kernel[grid]( - fpa, a, a_scale, - M, K, - fpa.stride(0), fpa.stride(1), - a.stride(0), a.stride(1), - a_scale.stride(0), - BLOCK_SIZE_M, BLOCK_SIZE_K, - ) - return a, a_scale - - -@triton.autotune( - configs=[ - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, 
num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8), - triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4), - ], - key=['M', 'N', 'K'], - reset_to_zero=['c_ptr'] -) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_asm, - stride_bk, stride_bn, - stride_bsn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. 
- pid = tl.program_id(axis=0) - pid_sp_k = tl.program_id(axis=1) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetics` section for details - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - as_ptrs = as_ptr + offs_am * stride_asm - bs_ptrs = bs_ptr + offs_bn * stride_bsn - a_scale = tl.load(as_ptrs, mask=offs_am < M, other=0.0) - b_scale = tl.load(bs_ptrs, mask=offs_bn < N, other=0.0) - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(c_ptr.dtype.element_ty) - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - if SPLIT_K == 1: - tl.store(c_ptrs, c, mask=c_mask) - else: - tl.atomic_add(c_ptrs, c, mask=c_mask) - - -def matmul_quantize_int8(fpa, b, b_scale, out=None): - a, a_scale = quantize_int8_perrow(fpa) - # a, a_scale = quantize_int8(fpa, axis=1) - return matmul_int8(a, a_scale, b, b_scale, out) - - -def matmul_int8(a, a_scale, b, b_scale, out=None): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - M, K = a.shape - K, N = b.shape - # Allocates output. - if out == None: - c = torch.zeros((M, N), device=a.device, dtype=torch.float16) - else: - c = out.fill_(0.) 
- grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - META['SPLIT_K'], - ) - matmul_kernel[grid]( - a, a_scale, b, b_scale, c, - M, N, K, - a.stride(0), a.stride(1), - a_scale.stride(0), - b.stride(0), b.stride(1), - b_scale.stride(0), - c.stride(0), c.stride(1), - ) - return c - - -def quantize_int8(weight, axis=0, tp_rank=0): - # Weight shape: [H1, H2] - # Scale shape: [H2] - scale = weight.abs().amax(axis, keepdim=True) / 127. - weight = (weight / scale).to(torch.int8) - # col major will accelerate i8xi8 kernel. - if axis == 0: - weight = weight.t().contiguous().t() - scale = scale.squeeze(axis) - return weight.contiguous().cuda(tp_rank), scale.contiguous().cuda(tp_rank) - - -def test_correct_int8(M=32, N=4096, K=4096): - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - int_a, scale_a = quantize_int8_perrow(a) - cos = torch.nn.CosineSimilarity(0) - print("Quantization cos", cos((int_a * scale_a.unsqueeze(1)).flatten().to(torch.float32), a.flatten().to(torch.float32))) - int_b, scale_b = quantize_int8(b, axis=0) - triton_output = matmul_int8(int_a, scale_a, int_b, scale_b) - torch_output = torch.matmul(a, b) - print(f"triton_output={triton_output}") - print(f"torch_output={torch_output}") - cos = torch.nn.CosineSimilarity(0) - print("Output cos", cos(triton_output.flatten().to(torch.float32), torch_output.flatten().to(torch.float32))) - - -def test_int8(M, K, N): - import time - - print("M: {} K: {} N: {}".format(M, K, N)) - torch.manual_seed(0) - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16).contiguous() - int_b, scale_b = quantize_int8(b, axis=0) - for _ in range(10): - # int_a, a_scale = quantize_int8(a, 1) - int_a, a_scale = quantize_int8_perrow(a) - triton_output = matmul_int8(int_a, a_scale, int_b, scale_b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - #int_a, a_scale, _ = quantize_int8(a, 1) - int_a, a_scale = quantize_int8_perrow(a) - torch.cuda.synchronize() - qt2 = time.time() - for _ in range(iters): - triton_output = matmul_int8(int_a, a_scale, int_b, scale_b) - torch.cuda.synchronize() - t2 = time.time() - quant_time = qt2 - t1 - triton_time = t2 - qt2 - triton_tflops = 2 * M * N * K * 1e-12 / (triton_time / iters) - quant_bandwith = 2 * M * K * 1e-9 / (quant_time / iters) - print("Triton time cost: {} (tflops {}) + quant: {} (bandwidth {})".format( - triton_time, triton_tflops, quant_time, quant_bandwith)) - for _ in range(10): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - iters = 512 - t1 = time.time() - for _ in range(iters): - torch_output = torch.matmul(a, b) - torch.cuda.synchronize() - t2 = time.time() - torch_time = t2 - t1 - torch_tflops = 2 * M * N * K * 1e-12 / (torch_time / iters) - print("Torch time cost: {} (tflops {})".format(t2 - t1, torch_tflops)) - return triton_time, torch_time, quant_time - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=['M'], # Argument names to use as an x-axis for the plot - x_vals=[32, 64, 128, 256] + [ - 512 * i * 2 for i in range(1, 17) - ], # Different possible values for `x_name` - line_arg='provider', # Argument name whose value corresponds to a different line in the plot - # Possible values for `line_arg` - line_vals=['cublas', 'triton-i8', 'triton-quant-i8', 'quant-perrow'], - # Label name for the lines - line_names=["cuBLAS", "Triton-i8", 
"Triton-Quant-i8", "Quant-perrow(GB/s)"], - # Line styles - styles=[('green', '-'), ('blue', '-'), ('red', '-'), ('purple', '-')], - ylabel="TFLOPS", # Label name for the y-axis - plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. - args={}, - ) -) -def benchmark(M, provider): - K = 10240 - N = 27392 * 2 // 8 - quantiles = [0.5, 0.2, 0.8] - if provider == 'cublas': - a = torch.randn((M, K), device='cuda', dtype=torch.float16) - b = torch.randn((K, N), device='cuda', dtype=torch.float16) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - int_a, a_scale = quantize_int8(a, axis=1) - int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_int8(int_a, a_scale, int_b, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'triton-quant-i8': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - b = torch.randn((K, N), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - int_b, b_scale = quantize_int8(b, axis=0) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul_quantize_int8(a, int_b, b_scale), quantiles=quantiles) - perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) - if provider == 'quant-perrow': - a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(torch.int8).contiguous() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: quantize_int8_perrow(a), quantiles=quantiles) - perf = lambda ms: 2 * M * K * 1e-9 / (ms * 1e-3) - return perf(ms), perf(min_ms), perf(max_ms) - - -def test_model_layer(bs, sqe_len, hidden, inter, tp): - st1 = 0 - st2 = 0 - st3 = 0 - t1, t2, t3 = test_int8(bs * sqe_len, hidden, hidden * 3 // tp) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, hidden // tp, hidden) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, hidden, inter * 2 // tp) - st1 += t1 - st2 += t2 - st3 += t3 - t1, t2, t3 = test_int8(bs * sqe_len, inter // tp, hidden) - st1 += t1 - st2 += t2 - st3 += t3 - print("Triton time {} Torch time {} Quant time {}".format(st1, st2, st3)) - - -if __name__ == "__main__": - test_correct_int8() - benchmark.run(show_plots=True, print_data=True) - - bs = 32 - hidden = 4096 - inter = 11008 - prefill_len = 512 - decode_len = 1 - tp = 1 - test_model_layer(bs, prefill_len, hidden, inter, tp) - test_model_layer(bs, decode_len, hidden, inter, tp) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index d5289298cf..8cbcc2e684 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -1,92 +1,13 @@ -import yaml -import collections -from .registry import QUANTMETHODS -from .backend import QUANT_BACKEND -from lightllm.utils.log_utils import init_logger - -# Import all type classes (they auto-register with QUANTMETHODS) -from .types import ( - NoQuantization, - FP8Block128Quantization, - FP8PerTokenQuantization, - W8A8Quantization, - AWQQuantization, -) - -# Re-export for backwards compatibility -from .types.awq import is_awq_marlin_compatible - -logger = init_logger(__name__) - - -class Quantcfg: - def __init__(self, 
network_config, quant_type="none", custom_cfg_path=None): - self.layer_num = network_config["n_layer"] - self.quant_type = quant_type - self.network_config_ = network_config - self._parse_custom_cfg(custom_cfg_path) - self._parse_network_config(network_config) - - def _parse_network_config(self, network_config): - hf_quantization_config = network_config.get("quantization_config", None) - if hf_quantization_config is None: - self.quantized_weight = False - self.static_activation = False - self.hf_quantization_config = None - return - self.quantized_weight = True - activation_scheme = network_config.get("activation_scheme", "dynamic") - self.static_activation = activation_scheme == "static" - self.hf_quantization_config = hf_quantization_config - self.hf_quantization_method = hf_quantization_config["quant_method"] - self._mapping_quant_method() - - def _mapping_quant_method(self): - if self.hf_quantization_method == "fp8": - block_size = self.hf_quantization_config.get("weight_block_size", None) - if block_size == [128, 128]: - self.quant_type = "fp8-block128" - logger.info( - f"Selected quant type: fp8-block128, backend: {QUANT_BACKEND.get_backend('fp8-block128').name}" - ) - else: - self.quant_type = "fp8-per-token" - logger.info( - f"Selected quant type: fp8-per-token, backend: {QUANT_BACKEND.get_backend('fp8-per-token').name}" - ) - elif self.hf_quantization_method == "awq": - self.quant_type = "awq" - logger.info("Selected quant type: awq (marlin auto-selected if compatible)") - else: - # TODO: more quant methods - raise NotImplementedError(f"Quant method {self.hf_quantization_method} not implemented yet.") - pass - - def _parse_custom_cfg(self, custom_cfg_path): - self.quant_cfg = collections.defaultdict(dict) - if custom_cfg_path is None: - return - - with open(custom_cfg_path, "r") as file: - data = yaml.safe_load(file) - - self.quant_type = data["quant_type"] - for layer_quant_cfg in data.get("mix_bits", []): - name = layer_quant_cfg["name"] - layer_nums = layer_quant_cfg.get("layer_nums", range(self.layer_num)) - layer_quant_type = layer_quant_cfg["quant_type"] - for layer_num in layer_nums: - self.quant_cfg[layer_num].update({name: layer_quant_type}) - - def get_quant_type(self, layer_num, name): - layer_config = self.quant_cfg.get(layer_num, None) - if layer_config is None: - return self.quant_type - quant_type = layer_config.get(name, self.quant_type) - return quant_type - - def get_quant_method(self, layer_num, name): - quant_type = self.get_quant_type(layer_num, name) - quant_method = QUANTMETHODS.get(quant_type) - quant_method.hf_quantization_config = self.hf_quantization_config - return quant_method +from .no_quant import NoQuantization +from .fp8_block128 import FP8Block128Quantization +from .fp8_per_token import FP8PerTokenQuantization +from .w8a8 import W8A8Quantization +from .awq import AWQQuantization + +__all__ = [ + "NoQuantization", + "FP8Block128Quantization", + "FP8PerTokenQuantization", + "W8A8Quantization", + "AWQQuantization", +] diff --git a/lightllm/common/quantization/types/awq.py b/lightllm/common/quantization/awq.py similarity index 100% rename from lightllm/common/quantization/types/awq.py rename to lightllm/common/quantization/awq.py diff --git a/lightllm/common/quantization/backend.py b/lightllm/common/quantization/backend.py deleted file mode 100644 index e6d081ec27..0000000000 --- a/lightllm/common/quantization/backend.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from enum import Enum, auto -from lightllm.utils.log_utils import init_logger - -logger = 
init_logger(__name__) - - -class BackendType(Enum): - TRITON = auto() - VLLM = auto() - DEEPGEMM = auto() - - -class BackendRegistry: - _instance = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - - self._force_triton = os.getenv("LIGHTLLM_USE_TRITON_QUANT", "0").upper() in ["1", "TRUE", "ON"] - - self._has_vllm = self._check_vllm() - self._has_deepgemm = self._check_deepgemm() - - if self._force_triton: - logger.info("LIGHTLLM_USE_TRITON_QUANT is set, forcing Triton backend for quantization") - else: - logger.info(f"Available quantization backends: vLLM={self._has_vllm}, DeepGEMM={self._has_deepgemm}") - - def _check_vllm(self) -> bool: - try: - from lightllm.utils.vllm_utils import HAS_VLLM - - return HAS_VLLM - except ImportError: - return False - - def _check_deepgemm(self) -> bool: - try: - import deep_gemm # noqa: F401 - - return True - except ImportError: - return False - - @property - def force_triton(self) -> bool: - return self._force_triton - - @property - def has_vllm(self) -> bool: - return self._has_vllm - - @property - def has_deepgemm(self) -> bool: - return self._has_deepgemm - - def get_backend(self, quant_type: str) -> BackendType: - if self._force_triton: - return BackendType.TRITON - - if quant_type == "fp8-block128": - if self._has_deepgemm: - return BackendType.DEEPGEMM - elif self._has_vllm: - return BackendType.VLLM - elif quant_type in ["w8a8", "fp8-per-token"]: - if self._has_vllm: - return BackendType.VLLM - - return BackendType.TRITON - - -QUANT_BACKEND = BackendRegistry() diff --git a/lightllm/common/quantization/types/fp8_block128.py b/lightllm/common/quantization/fp8_block128.py similarity index 100% rename from lightllm/common/quantization/types/fp8_block128.py rename to lightllm/common/quantization/fp8_block128.py diff --git a/lightllm/common/quantization/types/fp8_per_token.py b/lightllm/common/quantization/fp8_per_token.py similarity index 100% rename from lightllm/common/quantization/types/fp8_per_token.py rename to lightllm/common/quantization/fp8_per_token.py diff --git a/lightllm/common/quantization/types/no_quant.py b/lightllm/common/quantization/no_quant.py similarity index 100% rename from lightllm/common/quantization/types/no_quant.py rename to lightllm/common/quantization/no_quant.py diff --git a/lightllm/common/quantization/triton_quant/fp8/__init__.py b/lightllm/common/quantization/triton_quant/fp8/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/lightllm/common/quantization/types/__init__.py b/lightllm/common/quantization/types/__init__.py deleted file mode 100644 index 8cbcc2e684..0000000000 --- a/lightllm/common/quantization/types/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .no_quant import NoQuantization -from .fp8_block128 import FP8Block128Quantization -from .fp8_per_token import FP8PerTokenQuantization -from .w8a8 import W8A8Quantization -from .awq import AWQQuantization - -__all__ = [ - "NoQuantization", - "FP8Block128Quantization", - "FP8PerTokenQuantization", - "W8A8Quantization", - "AWQQuantization", -] diff --git a/lightllm/common/quantization/types/w8a8.py b/lightllm/common/quantization/w8a8.py similarity index 75% rename from lightllm/common/quantization/types/w8a8.py rename to lightllm/common/quantization/w8a8.py index e3b0ef592b..f803794a29 100644 --- a/lightllm/common/quantization/types/w8a8.py +++ 
b/lightllm/common/quantization/w8a8.py @@ -3,7 +3,7 @@ from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType +from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -24,7 +24,7 @@ @QUANTMETHODS.register(["w8a8", "vllm-w8a8"]) -class W8A8Quantization(QuantizationMethod): +class W8A8Quantization(QuantizationMethod, PlatformAwareOp): def __init__(self): super().__init__() from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -33,26 +33,18 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = False - self._backend = QUANT_BACKEND.get_backend("w8a8") - - if self._backend == BackendType.TRITON: - if not HAS_VLLM: - raise NotImplementedError( - "W8A8 Triton fallback is not yet implemented. " - "Please install vLLM or disable LIGHTLLM_USE_TRITON_QUANT." - ) - self._backend = BackendType.VLLM - logger.warning("W8A8 Triton fallback not implemented, falling back to vLLM backend") - - if self._backend == BackendType.VLLM and not HAS_VLLM: - raise RuntimeError("vLLM is required for W8A8 quantization but is not installed.") - - logger.info(f"W8A8Quantization using backend: {self._backend.name}") - @property def method_name(self): return "w8a8" + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: weight = weight.float().cuda(self.device_id_) scale = weight.abs().max(dim=-1)[0] / 127 @@ -71,10 +63,9 @@ def apply( use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO: Currently only vLLM backend is implemented - return self._apply_vllm(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + return self._forward(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) - def _apply_vllm( + def _cuda_forward( self, input_tensor: torch.Tensor, weight_pack: WeightPack, @@ -98,11 +89,3 @@ def _apply_vllm( cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 05897203ae..d927f22d1b 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -11,7 +11,6 @@ FusedMoeWeightEP, create_tp_moe_wegiht_obj, ) -from functools 
import partial from ..triton_kernel.weight_dequant import weight_dequant diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 821f94236b..0bf974df1f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -612,8 +612,16 @@ def make_argument_parser() -> argparse.ArgumentParser: help="""Hardware platform: cuda | musa""", ) parser.add_argument( - "--enable_torch_naive", + "--enable_torch_fallback", action="store_true", - help="""Use torch naive implementation for the op.""", + help="""Whether to enable torch naive implementation for the op. + If the op is not implemented for the platform, it will use torch naive implementation.""", + ) + parser.add_argument( + "--enable_triton_fallback", + action="store_true", + help="""Whether to enable triton implementation for the op. + If the op is not implemented for the platform and the hardware support triton, + it will use triton implementation.""", ) return parser From 55fdd2f81e2ec10ee6f44b76c5a1ff8b3f9f1f1c Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 15 Jan 2026 11:40:09 +0000 Subject: [PATCH 25/43] fix --- ...ernel.py => scaled_mm_per_token_kernel.py} | 33 ++-- lightllm/common/quantization/fp8_per_token.py | 2 +- lightllm/common/quantization/w8a8.py | 170 +++++++++++++++++- ...h.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json | 0 ...h.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json | 0 ...h.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json | 0 ...h.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json | 0 ...h.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 ...rch.bfloat16}_NVIDIA_GeForce_RTX_5090.json | 0 20 files changed, 191 insertions(+), 14 deletions(-) rename lightllm/common/basemodel/triton_kernel/quantization/{fp8w8a8_scaled_mm_per_token_kernel.py => scaled_mm_per_token_kernel.py} (93%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json (100%) rename 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) rename lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/{fp8_scaled_mm_per_token:v3 => scaled_mm_per_token:v1}/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json (100%) diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_scaled_mm_per_token_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py similarity index 93% rename from lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_scaled_mm_per_token_kernel.py rename to lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py index 7c76e82c9e..f14e8b2833 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_scaled_mm_per_token_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py @@ -11,8 +11,8 @@ from lightllm.utils.device_utils import triton_support_tensor_descriptor, is_5090_gpu -class Fp8ScaledMMKernelConfig(KernelConfigs): - kernel_name: str = 
"fp8_scaled_mm_per_token" +class ScaledMMKernelConfig(KernelConfigs): + kernel_name: str = "scaled_mm_per_token" @classmethod @lru_cache(maxsize=200) @@ -105,6 +105,7 @@ def _scaled_mm_per_token( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, + ACC_DTYPE: tl.constexpr, ): pid = tl.program_id(0) m_block_num = tl.cdiv(M, BLOCK_M) @@ -134,7 +135,7 @@ def _scaled_mm_per_token( a_s = tl.load(Ascale_ptrs) b_s = tl.load(Bscale_ptrs) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_K)): if USE_TMA: @@ -155,6 +156,7 @@ def _scaled_mm_per_token( a_ptrs += BLOCK_K * stride_ak b_ptrs += BLOCK_K * stride_bk + acc = acc.to(tl.float32) acc = acc * a_s[:, None] * b_s[None, :] acc = acc.to(out.dtype.element_ty) @@ -206,13 +208,13 @@ def _get_static_key(A, B, out_dtype): @autotune( - kernel_name="fp8_scaled_mm_per_token:v3", + kernel_name="scaled_mm_per_token:v1", configs_gen_func=get_test_configs, static_key_func=_get_static_key, run_key_func=lambda A: A.shape[0], mutates_args=["out"], ) -def fp8_scaled_mm_per_token( +def scaled_mm_per_token( A: torch.Tensor, B: torch.Tensor, Ascale: torch.Tensor, @@ -221,7 +223,7 @@ def fp8_scaled_mm_per_token( out: torch.Tensor, run_config=None, ) -> torch.Tensor: - """w8a8fp8 per-token quantization mm. + """w8a8 per-token quantization mm (supports fp8 and int8). Args: A: Matrix A with shape of [M, K]. @@ -239,7 +241,7 @@ def fp8_scaled_mm_per_token( M, K = A.shape _, N = B.shape if not run_config: - run_config = Fp8ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) + run_config = ScaledMMKernelConfig.try_to_get_best_config(M=M, N=N, K=K, out_dtype=out_dtype) NEED_N_MASK = N % run_config["BLOCK_N"] != 0 NEED_K_MASK = K % run_config["BLOCK_K"] != 0 grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),) @@ -283,6 +285,8 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): B_desc = None out_desc = None + ACC_DTYPE = tl.int32 if A.dtype == torch.int8 else tl.float32 + _scaled_mm_per_token[grid]( A=A, A_desc=A_desc, @@ -305,12 +309,17 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): B_IS_TRANS=B_is_trans, NEED_N_MASK=NEED_N_MASK, NEED_K_MASK=NEED_K_MASK, + ACC_DTYPE=ACC_DTYPE, **run_config, ) return out +fp8_scaled_mm_per_token = scaled_mm_per_token +int8_scaled_mm_per_token = scaled_mm_per_token + + if __name__ == "__main__": import time import os @@ -324,7 +333,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): M_list = [1, 2, 4, 8, 16, 32, 48] print(f"{'='*80}") - print(f"Starting Autotune for FP8 Scaled MM (N={N}, K={K})") + print(f"Starting Autotune for Scaled MM (N={N}, K={K})") print(f"M values to test: {M_list}") print(f"Total configs per M: {len(get_test_configs())}") print(f"{'='*80}\n") @@ -360,7 +369,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): gt_C = d_A.mm(d_B) # 运行kernel验证正确性 - fp8_scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) + scaled_mm_per_token(A_verify, B, Ascale_verify, Bscale, output_dtype, out_verify) # 计算cosine similarity cosine_sim = F.cosine_similarity(out_verify.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) @@ -390,7 +399,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): A = test_data[M]["A"] Ascale = test_data[M]["Ascale"] out = test_data[M]["out"] - fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + 
scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) print(f"[M={M}] Autotune completed!") Autotuner.end_autotune_warmup() @@ -418,7 +427,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): gt_C = d_A.mm(d_B) # 运行一次确保结果正确 - fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) sgl_res = fp8_scaled_mm(A, B, Ascale, Bscale, output_dtype) cosine_sim = F.cosine_similarity(out.flatten().unsqueeze(0), gt_C.flatten().unsqueeze(0), dim=1) @@ -437,7 +446,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): ms_sgl = triton.testing.do_bench(fn_sgl, warmup=25, rep=100) # Our kernel - fn_ours = lambda: fp8_scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) + fn_ours = lambda: scaled_mm_per_token(A, B, Ascale, Bscale, output_dtype, out) ms_ours = triton.testing.do_bench_cudagraph(fn_ours, rep=100) print(f"[M={M}] BF16: {ms_bf16:.3f} ms") diff --git a/lightllm/common/quantization/fp8_per_token.py b/lightllm/common/quantization/fp8_per_token.py index c49bc89ff3..ce7f9342c9 100644 --- a/lightllm/common/quantization/fp8_per_token.py +++ b/lightllm/common/quantization/fp8_per_token.py @@ -4,7 +4,7 @@ from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.quantization.registry import QUANTMETHODS from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index f803794a29..721807356a 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -4,6 +4,10 @@ from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.quantization.registry import QUANTMETHODS from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import ( + fp8_scaled_mm_per_token, + int8_scaled_mm_per_token, +) from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -23,6 +27,27 @@ cutlass_scaled_mm = None +try: + from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops + + if HAS_LIGHTLLM_KERNEL: + + def scaled_fp8_quant(tensor, *args, **kwargs): + return light_ops.per_token_quant_bf16_fp8(tensor) + + else: + if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant + else: + scaled_fp8_quant = None +except ImportError: + HAS_LIGHTLLM_KERNEL = False + if HAS_VLLM: + scaled_fp8_quant = vllm_ops.scaled_fp8_quant + else: + scaled_fp8_quant = None + + @QUANTMETHODS.register(["w8a8", "vllm-w8a8"]) class W8A8Quantization(QuantizationMethod, PlatformAwareOp): def __init__(self): @@ -63,13 +88,53 @@ def apply( use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self._forward(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + return self._forward( + input_tensor=input_tensor, + weight_pack=weight_pack, + out=out, + workspace=workspace, + use_custom_tensor_mananger=use_custom_tensor_mananger, + bias=bias, + ) + + def _triton_forward( + self, + 
input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + # TODO: support fp8 quantization triton + + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + out = int8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) + + if bias is not None: + out.add_(bias) + return out def _cuda_forward( self, input_tensor: torch.Tensor, weight_pack: WeightPack, out: Optional[torch.Tensor], + workspace: Optional[torch.Tensor], use_custom_tensor_mananger: bool, bias: Optional[torch.Tensor], ) -> torch.Tensor: @@ -89,3 +154,106 @@ def _cuda_forward( cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out + + +class Fp8W8A8Quantization(QuantizationMethod, PlatformAwareOp): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + self.is_moe = False + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "f8w8a8" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + """Quantize weights using per-token FP8 quantization.""" + qweight, weight_scale = scaled_fp8_quant( + weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + ) + output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) + output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) + return + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + if self.is_moe: + assert num_experts > 1, "Number of experts must be greater than 1 for MOE" + # per-tensor weight quantization for moe + weight_scale = torch.empty((num_experts,), dtype=torch.float32).cuda(device_id) + else: + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self._forward(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) + + def _cuda_forward( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out 
is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) + return out + + def _apply_triton( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor], + use_custom_tensor_mananger: bool, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + qweight = weight_pack.weight.t() + weight_scale = weight_pack.weight_scale + + # TODO: support fp8 quantization triton + + x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) + + m = input_tensor.shape[0] + n = qweight.shape[1] + + if out is None: + if use_custom_tensor_mananger: + out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + + out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) + + if bias is not None: + out.add_(bias) + return out diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_4090_D.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=13824,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=14336,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=1536,N=8960,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=4096,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=13824,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json 
b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=2048,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=28672,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=4096,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=5120,N=5120,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from 
lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{K=8960,N=1536,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json similarity index 100% rename from lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/fp8_scaled_mm_per_token:v3/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json rename to lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/scaled_mm_per_token:v1/{N=14336,out_dtype=torch.bfloat16}_NVIDIA_GeForce_RTX_5090.json From 208f1b09d0137d3fb2e9ce5b7a9d50798aaf3476 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 15 Jan 2026 11:40:53 +0000 Subject: [PATCH 26/43] unit_test --- unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py index 2c0b7bf76e..e6a0d52c72 100644 --- a/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py +++ b/unit_tests/common/quantization/test_fp8_scaled_mm_per_token.py @@ -1,7 +1,7 @@ import torch import pytest import torch.nn.functional as F -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token def is_fp8_native_supported(): From 206b170917aba883b9342e2db0c4b158abf32db3 Mon Sep 17 00:00:00 2001 From: sangchengmeng <101796078+SangChengC@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:19:18 +0800 Subject: [PATCH 27/43] check image tag and image num (#1176) Co-authored-by: sangchengmeng --- .../int8kv/int8kv_flash_decoding_diverse_stage1.py | 2 +- lightllm/models/internvl/model.py | 8 ++++++++ lightllm/models/qwen2_vl/model.py | 4 ++++ lightllm/models/qwen2_vl/vision_process.py | 2 ++ lightllm/models/qwen_vl/model.py | 3 ++- lightllm/models/tarsier2/model.py | 4 ++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 4dfaffef68..295ae66ab3 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -291,7 +291,7 @@ def flash_decode_stage1( assert k.stride() == v.stride() NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS - + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 
6d264a4267..ccb76d3512 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -149,6 +149,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) # audio @@ -174,6 +178,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("audio token error") except ValueError: break + if multimodal_params: + audio_cnt = len(multimodal_params.audios) + if audio_cnt != audio_id: + raise ValueError(f"invalid audio tag num: {audio_cnt} vs {audio_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index dd4181fbfb..237c4ad897 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -79,6 +79,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids) return input_ids diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 1c4f60794d..f2cd38ec8e 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -184,6 +184,8 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: return self._preprocess_bydevice(image, device="cpu") def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: + if image.mode != "RGB": + image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index d942d68497..0c6fa31f47 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -86,7 +86,8 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): input_ids.extend(origin_ids[end:]) if multimodal_params: image_cnt = len(multimodal_params.images) - assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") return input_ids diff --git a/lightllm/models/tarsier2/model.py b/lightllm/models/tarsier2/model.py index dad252b979..10a7f368c4 100644 --- a/lightllm/models/tarsier2/model.py +++ b/lightllm/models/tarsier2/model.py @@ -78,6 +78,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids From 7a0a4d7216eba852c45aec7e9a1bbed95a3dc326 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19
Jan 2026 13:47:40 +0000 Subject: [PATCH 28/43] fix --- .../common/basemodel/attention/fa3/fp8.py | 2 +- .../layer_weights/base_layer_weight.py | 9 - .../meta_weights/att_sink_weight.py | 3 - .../meta_weights/embedding_weight.py | 40 ++- .../fused_moe/fused_moe_weight_ep.py | 5 +- .../fused_moe/fused_moe_weight_tp.py | 1 + .../meta_weights/mm_weight/mm_weight.py | 3 - .../layer_weights/meta_weights/norm_weight.py | 29 +- .../basemodel/triton_kernel/norm/__init__.py | 0 .../common/fused_moe/grouped_fused_moe.py | 2 +- .../common/fused_moe/grouped_fused_moe_ep.py | 2 +- lightllm/common/quantization/__init__.py | 95 +++++- lightllm/common/quantization/awq.py | 267 +++++++-------- lightllm/common/quantization/deepgemm.py | 133 ++++++++ lightllm/common/quantization/fp8_block128.py | 216 ------------ lightllm/common/quantization/fp8_per_token.py | 172 ---------- lightllm/common/quantization/no_quant.py | 3 +- lightllm/common/quantization/registry.py | 16 +- lightllm/common/quantization/w8a8.py | 309 +++++++++--------- .../pre_and_post_layer_weight.py | 24 +- .../layer_infer/transformer_layer_infer.py | 14 +- .../pre_and_post_layer_weight.py | 23 +- .../pre_and_post_layer_weight.py | 23 +- lightllm/server/api_cli.py | 4 +- 24 files changed, 583 insertions(+), 812 deletions(-) create mode 100644 lightllm/common/basemodel/triton_kernel/norm/__init__.py create mode 100644 lightllm/common/quantization/deepgemm.py delete mode 100644 lightllm/common/quantization/fp8_block128.py delete mode 100644 lightllm/common/quantization/fp8_per_token.py diff --git a/lightllm/common/basemodel/attention/fa3/fp8.py b/lightllm/common/basemodel/attention/fa3/fp8.py index 3feed1ef46..12b2b0dfa8 100644 --- a/lightllm/common/basemodel/attention/fa3/fp8.py +++ b/lightllm/common/basemodel/attention/fa3/fp8.py @@ -4,7 +4,7 @@ from typing import Optional, TYPE_CHECKING from lightllm.utils.sgl_utils import flash_attn_with_kvcache from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant +from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops from typing import Union from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 1875e2c3b3..6bdeb64d20 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -26,14 +26,5 @@ def init_static_params(self): """ pass - def verify_load(self): - """ - verify all load is ok - """ - for attr_name in dir(self): - attr = getattr(self, attr_name) - if isinstance(attr, BaseWeight): - assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails." 
- def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 3f8e1f50ab..1c22bcb7d9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -18,6 +18,3 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): t_weight = weights[self.weight_name] start_head_index, end_head_index = self._get_head_tp_split_params(weight=t_weight) self.weight = t_weight[start_head_index:end_head_index].to(self.data_type_).cuda(get_current_device_id()) - - def verify_load(self): - return self.weight is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index e228d5c869..df9050d4fc 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -44,18 +44,15 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): def _native_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty ) -> torch.Tensor: - # Adjust input_ids for tp split adjusted_ids = input_ids - self.tp_vocab_start_id - # Clamp to valid range for this partition adjusted_ids = torch.clamp(adjusted_ids, 0, self.weight.shape[0] - 1) - # Use PyTorch native embedding result = torch.nn.functional.embedding(adjusted_ids, self.weight) if out is not None: out.copy_(result) return out return result - def _cuda_forward( + def _triton_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: if out is None: @@ -71,6 +68,17 @@ def _cuda_forward( ) return out + def _cuda_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + return self._triton_forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + + def _musa_forward( + self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. 
+ return self._triton_forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + def __call__( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: @@ -84,7 +92,7 @@ def __init__( vocab_size: int, weight_name: str, data_type: torch.dtype, - shared_weight: Optional[EmbeddingWeight] = None, + embedding_weight: Optional[EmbeddingWeight] = None, ): super().__init__() self.dim = dim @@ -97,23 +105,19 @@ def __init__( self.tp_vocab_end_id = int(split_indexes[self.tp_rank_ + 1]) self.weight_name: str = weight_name self.data_type_ = data_type - self._shared_weight = shared_weight - if shared_weight is None: - self._create_weight() - - @property - def weight(self) -> torch.Tensor: - if self._shared_weight is not None: - return self._shared_weight.weight - return self._weight + self._embedding_weight = embedding_weight + self._create_weight() def _create_weight(self): + if self._embedding_weight is not None: + self.weight = self._embedding_weight.weight + return tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id - self._weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) def load_hf_weights(self, weights: Dict[str, torch.Tensor]): - # When using shared weight, no need to load - EmbeddingWeight already loaded it - if self._shared_weight is not None: + # When tie_word_embeddings=True, no need to load - EmbeddingWeight already loaded it + if self._embedding_weight is not None: return if self.weight_name not in weights: return @@ -123,7 +127,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): loaded_vocab_size == self.vocab_size ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") - self._weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) def _native_forward( self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index a84d198937..342026de21 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -15,7 +15,7 @@ from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, tma_align_input_scale, ) @@ -741,6 +741,3 @@ def _cuda(self, cpu_tensor): if self.quantized_weight: return cpu_tensor.contiguous().cuda(self.device_id_) return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_) - - def verify_load(self): - return self.w1 is not None and self.w2 is not None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index c6b3dc9656..876dc44bd2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -90,6 +90,7 @@ def __init__( self.num_fused_shared_experts = num_fused_shared_experts self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.split_inter_size = split_inter_size + self.data_type_ = data_type self.hidden_size = network_config.get("hidden_size") self.e_score_correction_bias = None self.scoring_func = network_config.get("scoring_func", "softmax") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index a7288b8187..728ed82fa9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -120,9 +120,6 @@ def load_hf_weights(self, weights): for sub_child_index, param_name in enumerate(self.weight_zero_point_names): self._load_weight_zero_point(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - def verify_load(self) -> bool: - return True - def _create_weight(self): self.bias = None if self.bias_names is not None: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index df12ec9b1b..d7bbe5567a 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -2,9 +2,9 @@ from typing import Optional, Dict from .base_weight import BaseWeightTpl from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size -from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.common.basemodel.triton_kernel.layernorm import layernorm_forward -from lightllm.common.basemodel.triton_kernel.qk_norm import qk_rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_forward from .platform_op import PlatformAwareOp @@ -53,6 +53,12 @@ def _cuda_forward( # only triton implementation is supported for rmsnorm on cuda platform return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def _musa_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. + return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: @@ -105,6 +111,12 @@ def _cuda_forward( # only triton implementation is supported for layernorm on cuda platform return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def _musa_forward( + self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty + ) -> torch.Tensor: + # triton implementation is supported by musa. 
+ return self._triton_forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: @@ -188,15 +200,22 @@ def _native_forward( input.copy_(x) return + def _triton_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + assert input.ndim == 2 and self.weight.ndim == 1 + return qk_rmsnorm_forward(x=input, weight=self.weight, eps=eps) + def _cuda_forward( self, input: torch.Tensor, eps: float, ) -> None: - assert input.ndim == 2 and self.weight.ndim == 1 - qk_rmsnorm_forward(x=input, weight=self.weight, eps=eps) + self._triton_forward(input=input, eps=eps) return + def _musa_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + # triton implementation is supported by musa. + return self._triton_forward(input=input, eps=eps) + def __call__( self, input: torch.Tensor, diff --git a/lightllm/common/basemodel/triton_kernel/norm/__init__.py b/lightllm/common/basemodel/triton_kernel/norm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py index 758d83ba34..f29d3a2a0e 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/fused_moe/grouped_fused_moe.py @@ -28,7 +28,7 @@ from .moe_kernel_configs import MoeGroupedGemmKernelConfig from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.utils.torch_ops_utils import direct_register_custom_op from lightllm.common.triton_utils.autotuner import autotune diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/fused_moe/grouped_fused_moe_ep.py index 5cc0d7a9be..2a577890b2 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/fused_moe/grouped_fused_moe_ep.py @@ -8,7 +8,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, tma_align_input_scale, ) diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 8cbcc2e684..bf99622ef2 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -1,13 +1,82 @@ -from .no_quant import NoQuantization -from .fp8_block128 import FP8Block128Quantization -from .fp8_per_token import FP8PerTokenQuantization -from .w8a8 import W8A8Quantization -from .awq import AWQQuantization - -__all__ = [ - "NoQuantization", - "FP8Block128Quantization", - "FP8PerTokenQuantization", - "W8A8Quantization", - "AWQQuantization", -] +import yaml +import collections +from .registry import QUANTMETHODS +from .w8a8 import * +from .deepgemm import * +from .awq import * +from .no_quant import * +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Quantcfg: + def __init__(self, network_config, quant_type="none",
custom_cfg_path=None): + self.layer_num = network_config["n_layer"] + self.quant_type = quant_type + self.network_config_ = network_config + self._parse_custom_cfg(custom_cfg_path) + self._parse_network_config(network_config) + + def _parse_network_config(self, network_config): + hf_quantization_config = network_config.get("quantization_config", None) + if hf_quantization_config is None: + self.quantized_weight = False + self.static_activation = False + self.hf_quantization_config = None + return + self.quantized_weight = True + activation_scheme = network_config.get("activation_scheme", "dynamic") + self.static_activation = activation_scheme == "static" + self.hf_quantization_config = hf_quantization_config + self.hf_quantization_method = hf_quantization_config["quant_method"] + self._mapping_quant_method() + + def _mapping_quant_method(self): + if self.hf_quantization_method == "fp8": + block_size = self.hf_quantization_config.get("weight_block_size", None) + if block_size == [128, 128]: + from lightllm.common.quantization.deepgemm_quant import HAS_DEEPGEMM + + if HAS_DEEPGEMM: + self.quant_type = "deepgemm-fp8w8a8-b128" + else: + self.quant_type = "vllm-fp8w8a8-b128" + logger.info(f"select fp8w8a8-b128 quant way: {self.quant_type}") + elif self.hf_quantization_method == "awq": + self.quant_type = "awq" + if is_awq_marlin_compatible(self.hf_quantization_config): + self.quant_type = "awq_marlin" + logger.info(f"select awq quant way: {self.quant_type}") + else: + # TODO: more quant method + pass + + def _parse_custom_cfg(self, custom_cfg_path): + self.quant_cfg = collections.defaultdict(dict) + if custom_cfg_path is None: + return + + with open(custom_cfg_path, "r") as file: + data = yaml.safe_load(file) + + self.quant_type = data["quant_type"] + for layer_quant_cfg in data.get("mix_bits", []): + name = layer_quant_cfg["name"] + layer_nums = layer_quant_cfg.get("layer_nums", range(self.layer_num)) + layer_quant_type = layer_quant_cfg["quant_type"] + for layer_num in layer_nums: + self.quant_cfg[layer_num].update({name: layer_quant_type}) + + def get_quant_type(self, layer_num, name): + layer_config = self.quant_cfg.get(layer_num, None) + if layer_config is None: + return self.quant_type + quant_type = layer_config.get(name, self.quant_type) + return quant_type + + def get_quant_method(self, layer_num, name): + quant_type = self.get_quant_type(layer_num, name) + quant_method = QUANTMETHODS.get(quant_type) + quant_method.hf_quantization_config = self.hf_quantization_config + return quant_method diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index eedc5b67b5..ddb7674dd1 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -39,40 +39,37 @@ TYPE_MAP = {} -def is_awq_marlin_compatible(quantization_config: dict[str, Any]) -> bool: - if not HAS_VLLM: - return False - - quant_method = quantization_config.get("quant_method", "").lower() - num_bits = quantization_config.get("bits") - group_size = quantization_config.get("group_size") - zero_point = quantization_config.get("zero_point") - - if not torch.cuda.is_available(): - return False +class AWQBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." 
+ from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - if quant_method != "awq": - return False + self.cache_manager = g_cache_manager - if num_bits is None or group_size is None or zero_point is None: - return False + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + raise NotImplementedError("AWQ online quantization is not supported yet.") - if num_bits not in TYPE_MAP: - return False + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("AWQ online quantization is not supported yet.") - return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) + @property + def method_name(self): + return "awq-base" -@QUANTMETHODS.register(["awq", "awq_marlin"]) -class AWQQuantization(QuantizationMethod): +@QUANTMETHODS.register("awq", platform="cuda") +class AWQW4A16QuantizationMethod(AWQBaseQuantizationMethod): def __init__(self): super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - if not HAS_VLLM: - raise RuntimeError("vLLM is required for AWQ quantization but is not installed.") - - self.cache_manager = g_cache_manager self.pack_factor = 8 self.weight_scale_suffix = "scales" self.weight_zero_point_suffix = "qzeros" @@ -80,38 +77,11 @@ def __init__(self): self.has_weight_scale = True self.has_weight_zero_point = True - self._use_marlin = False - self._marlin_initialized = False - - def _init_marlin(self): - if self._marlin_initialized: - return - - self.nbits = 4 - self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) - self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) - self.workspace = marlin_make_workspace_new(torch.device("cuda")) - self.vllm_quant_type = TYPE_MAP[self.nbits] - self.tile_size = 16 - self._marlin_initialized = True - - def _check_and_set_marlin(self): - if self.hf_quantization_config is None: - self._use_marlin = False - return - - self._use_marlin = is_awq_marlin_compatible(self.hf_quantization_config) - if self._use_marlin: - self._init_marlin() - logger.info("AWQQuantization using Marlin backend") - else: - logger.info("AWQQuantization using basic AWQ backend") - @property def method_name(self): return "awq" - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): raise NotImplementedError("AWQ online quantization is not supported yet.") def apply( @@ -122,22 +92,6 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if not hasattr(self, "_checked_marlin"): - self._check_and_set_marlin() - self._checked_marlin = True - - if self._use_marlin: - return self._apply_marlin(input_tensor, weight_pack, out, bias) - else: - return self._apply_basic(input_tensor, weight_pack, out, bias) - - def _apply_basic( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - bias: Optional[torch.Tensor], ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -154,12 +108,81 @@ def _apply_basic( out.add_(bias) return out - def _apply_marlin( + def create_weight( + self, out_dim: 
int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + group_size = self.hf_quantization_config["group_size"] + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight_zero_point = torch.empty( + expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + start_idx = start_idx // self.pack_factor + end_idx = start_idx + weight_zero_point.shape[1] + weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) + return + + +@QUANTMETHODS.register("awq_marlin", platform="cuda") +class AWQMARLINW4A16QuantizationMethod(AWQBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.pack_factor = 8 + self.nbits = 4 + self.weight_scale_suffix = "scales" + self.weight_zero_point_suffix = "qzeros" + self.weight_suffix = "qweight" + self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) + self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) + self.workspace = marlin_make_workspace_new(torch.device("cuda")) + self.vllm_quant_type = TYPE_MAP[self.nbits] + self.has_weight_scale = True + self.has_weight_zero_point = True + self.tile_size = 16 + + @property + def method_name(self): + return "awq_marlin" + + def quantize(self, weight: torch.Tensor, offset: int = 0) -> WeightPack: + raise NotImplementedError("AWQ online quantization is not supported yet.") + + def params_repack( + self, weight: torch.Tensor, weight_scale: torch.Tensor, weight_zero_point: torch.Tensor, dtype_type: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Some quantization methods, such as AWQ, need to repack the quantized parameters so that the kernels reach their best performance. + """ + weight = self._process_weight_after_loading(weight.cuda(get_current_device_id())) + weight_scale = self._process_weight_scale_after_loading( + weight_scale.cuda(get_current_device_id()).to(dtype_type) + ) + weight_zero_point = self._process_weight_zero_point_after_loading( + weight_zero_point.cuda(get_current_device_id()) + ) + return weight, weight_scale, weight_zero_point + + def apply( self, input_tensor: torch.Tensor, weight_pack: WeightPack, - out: Optional[torch.Tensor], - bias: Optional[torch.Tensor], + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight weight_scale = weight_pack.weight_scale @@ -200,30 +223,6 @@ def _apply_marlin( def create_weight( self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - if not hasattr(self, "_checked_marlin"): -
self._check_and_set_marlin() - self._checked_marlin = True - - if self._use_marlin: - return self._create_weight_marlin(out_dim, in_dim, dtype, device_id, num_experts) - else: - return self._create_weight_basic(out_dim, in_dim, dtype, device_id, num_experts) - - def _create_weight_basic( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - group_size = self.hf_quantization_config["group_size"] - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) - weight_zero_point = torch.empty( - expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) - - def _create_weight_marlin( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 ) -> WeightPack: self.n = out_dim self.k = in_dim @@ -239,20 +238,6 @@ def _create_weight_marlin( return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - if not hasattr(self, "_checked_marlin"): - self._check_and_set_marlin() - self._checked_marlin = True - - if self._use_marlin: - self._load_weight_marlin(weight, weight_pack, start_idx) - else: - self._load_weight_basic(weight, weight_pack, start_idx) - - def _load_weight_basic(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - start_idx = start_idx // self.pack_factor - weight_pack.weight[:, start_idx : start_idx + weight.shape[1]].copy_(weight) - - def _load_weight_marlin(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: assert self.hf_quantization_config is not None, "hf_quantization_config is not set" device_id = get_current_device_id() repack_weight = vllm_ops.awq_marlin_repack( @@ -263,21 +248,9 @@ def _load_weight_marlin(self, weight: torch.Tensor, weight_pack: WeightPack, sta ) start_idx = start_idx // self.pack_factor * self.tile_size weight_pack.weight[:, start_idx : start_idx + repack_weight.shape[1]].copy_(repack_weight) + return def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - if not hasattr(self, "_checked_marlin"): - self._check_and_set_marlin() - self._checked_marlin = True - - if self._use_marlin: - self._load_weight_scale_marlin(weight_scale, weight_pack, start_idx) - else: - self._load_weight_scale_basic(weight_scale, weight_pack, start_idx) - - def _load_weight_scale_basic(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[:, start_idx : start_idx + weight_scale.shape[1]].copy_(weight_scale) - - def _load_weight_scale_marlin(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: assert self.hf_quantization_config is not None, "hf_quantization_config is not set" group_size = self.hf_quantization_config["group_size"] device_id = get_current_device_id() @@ -288,27 +261,9 @@ def _load_weight_scale_marlin(self, weight_scale: torch.Tensor, weight_pack: Wei group_size=self.hf_quantization_config["group_size"], ) weight_pack.weight_scale[:, start_idx : start_idx + 
repack_weight_scale.shape[1]].copy_(repack_weight_scale) + return def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - if not hasattr(self, "_checked_marlin"): - self._check_and_set_marlin() - self._checked_marlin = True - - if self._use_marlin: - self._load_weight_zero_point_marlin(weight_zero_point, weight_pack, start_idx) - else: - self._load_weight_zero_point_basic(weight_zero_point, weight_pack, start_idx) - - def _load_weight_zero_point_basic( - self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int - ) -> None: - start_idx = start_idx // self.pack_factor - end_idx = start_idx + weight_zero_point.shape[1] - weight_pack.weight_zero_point[:, start_idx:end_idx].copy_(weight_zero_point) - - def _load_weight_zero_point_marlin( - self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int - ) -> None: device_id = get_current_device_id() repack_weight_zero_point = awq_to_marlin_zero_points( weight_zero_point.cuda(device_id), @@ -320,3 +275,29 @@ def _load_weight_zero_point_marlin( weight_pack.weight_zero_point[:, start_idx : start_idx + repack_weight_zero_point.shape[1]].copy_( repack_weight_zero_point ) + return + + +# adapted from +# https://github.com/vllm-project/vllm/blob/aef368aa08572505b820db01da82e2fbb3d43a72/vllm/model_executor/layers/quantization/awq_marlin.py#L211-L212 +def is_awq_marlin_compatible(quantization_config: dict[str, Any]): + # Extract data from quant config. + quant_method = quantization_config.get("quant_method", "").lower() + num_bits = quantization_config.get("bits") + group_size = quantization_config.get("group_size") + zero_point = quantization_config.get("zero_point") + + if not torch.cuda.is_available(): + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. 
+ if num_bits is None or group_size is None or zero_point is None: + return False + + if num_bits not in TYPE_MAP: + return False + + return check_marlin_supported(quant_type=TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point) diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py new file mode 100644 index 0000000000..c9f227120d --- /dev/null +++ b/lightllm/common/quantization/deepgemm.py @@ -0,0 +1,133 @@ +import torch +from typing import Optional + +from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.quantization.registry import QUANTMETHODS +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +try: + import deep_gemm + + HAS_DEEPGEMM = True +except ImportError: + HAS_DEEPGEMM = False + + +class DeepGEMMBaseQuantizationMethod(QuantizationMethod): + def __init__(self): + super().__init__() + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + + self.cache_manager = g_cache_manager + assert HAS_DEEPGEMM, "deepgemm is not installed, you can't use quant api of it" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + raise NotImplementedError("Not implemented") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") + + @property + def method_name(self): + return "deepgemm-base" + + +@QUANTMETHODS.register(["deepgemm-fp8w8a8-b128"], platform="cuda") +class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 128 + self.weight_suffix = None + self.weight_zero_point_suffix = None + self.weight_scale_suffix = "weight_scale_inv" + self.has_weight_scale = True + self.has_weight_zero_point = False + + @property + def method_name(self): + return "deepgemm-fp8w8a8-b128" + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0): + from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_quant_kernel import weight_quant + + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: "WeightPack", + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = weight_pack.weight + weight_scale = weight_pack.weight_scale + input_scale = None + alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty + m, k = input_tensor.shape + n = qweight.shape[0] + if input_scale is None: + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, + self.block_size, + dtype=qweight.dtype, + column_major_scales=True, + scale_tma_aligned=True, + alloc_func=alloc_func, + ) + + if out is None: + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) 
+ _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) + return out + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) + return + + def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_scale[ + start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size + ].copy_(weight_scale) + return + + def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: + weight_pack.weight_zero_point[ + start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size + ].copy_(weight_zero_point) + return + + +def _deepgemm_fp8_nt(a_tuple, b_tuple, out): + if HAS_DEEPGEMM: + if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): + return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) + if hasattr(deep_gemm, "fp8_gemm_nt"): + return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) + raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") diff --git a/lightllm/common/quantization/fp8_block128.py b/lightllm/common/quantization/fp8_block128.py deleted file mode 100644 index 4144dddde2..0000000000 --- a/lightllm/common/quantization/fp8_block128.py +++ /dev/null @@ -1,216 +0,0 @@ -import torch -from typing import Optional - -from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack -from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -try: - import deep_gemm - - HAS_DEEPGEMM = True -except ImportError: - HAS_DEEPGEMM = False - -try: - from lightllm.utils.vllm_utils import HAS_VLLM - - if HAS_VLLM: - from lightllm.utils.vllm_utils import cutlass_scaled_mm - else: - cutlass_scaled_mm = None -except ImportError: - HAS_VLLM = False - cutlass_scaled_mm = None - - -def _deepgemm_fp8_nt(a_tuple, b_tuple, out): - if hasattr(deep_gemm, "gemm_fp8_fp8_bf16_nt"): - return deep_gemm.gemm_fp8_fp8_bf16_nt([a_tuple[0], a_tuple[1]], [b_tuple[0], b_tuple[1]], out) - if hasattr(deep_gemm, "fp8_gemm_nt"): - return deep_gemm.fp8_gemm_nt((a_tuple[0], a_tuple[1]), (b_tuple[0], b_tuple[1]), out) - raise RuntimeError("deep_gemm does not provide fp8 NT GEMM kernel in this version") - - -@QUANTMETHODS.register(["fp8-block128"]) -class FP8Block128Quantization(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = 
g_cache_manager - self.block_size = 128 - self.weight_scale_suffix = "weight_scale_inv" - self.has_weight_scale = True - self.has_weight_zero_point = False - - self._backend = QUANT_BACKEND.get_backend("fp8-block128") - logger.info(f"FP8Block128Quantization using backend: {self._backend.name}") - - @property - def method_name(self): - return "fp8-block128" - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant - - device = output.weight.device - weight, scale = weight_quant(weight.cuda(device), self.block_size) - output.weight[offset : offset + weight.shape[0], :].copy_(weight) - output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) - return - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty - m, k = input_tensor.shape - - if self._backend == BackendType.DEEPGEMM: - return self._apply_deepgemm(input_tensor, weight_pack, out, alloc_func, bias) - elif self._backend == BackendType.VLLM: - return self._apply_vllm(input_tensor, weight_pack, out, alloc_func, bias) - else: - return self._apply_triton(input_tensor, weight_pack, out, alloc_func, bias) - - def _apply_deepgemm( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - alloc_func, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - m, k = input_tensor.shape - n = qweight.shape[0] - - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, - self.block_size, - dtype=qweight.dtype, - column_major_scales=True, - scale_tma_aligned=True, - alloc_func=alloc_func, - ) - - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - _deepgemm_fp8_nt((qinput_tensor, input_scale), (qweight, weight_scale), out) - - if bias is not None: - out.add_(bias) - return out - - def _apply_vllm( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - alloc_func, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale.t() - m, k = input_tensor.shape - n = qweight.shape[1] - - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func - ) - - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - if n % 128 != 0: - w8a8_block_fp8_matmul( - qinput_tensor, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - else: - input_scale = input_scale.t().contiguous().t() - cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) - return out - - if bias is not None: - out.add_(bias) - return out - - def _apply_triton( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - alloc_func, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight - weight_scale = weight_pack.weight_scale - m, k = input_tensor.shape - n = 
qweight.shape[1] - - qinput_tensor, input_scale = per_token_group_quant_fp8( - input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func - ) - - if out is None: - out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - w8a8_block_fp8_matmul( - qinput_tensor, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - - if bias is not None: - out.add_(bias) - return out - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty( - expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) - - def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight[start_idx : start_idx + weight.shape[0]].copy_(weight) - return - - def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[ - start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size - ].copy_(weight_scale) - return - - def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - if weight_pack.weight_zero_point is not None: - weight_pack.weight_zero_point[ - start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size - ].copy_(weight_zero_point) - return diff --git a/lightllm/common/quantization/fp8_per_token.py b/lightllm/common/quantization/fp8_per_token.py deleted file mode 100644 index ce7f9342c9..0000000000 --- a/lightllm/common/quantization/fp8_per_token.py +++ /dev/null @@ -1,172 +0,0 @@ -import torch -from typing import Optional - -from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack -from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.common.quantization.backend import QUANT_BACKEND, BackendType -from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -try: - from lightllm.utils.vllm_utils import HAS_VLLM - - if HAS_VLLM: - from lightllm.utils.vllm_utils import vllm_ops, cutlass_scaled_mm - else: - vllm_ops = None - cutlass_scaled_mm = None -except ImportError: - HAS_VLLM = False - vllm_ops = None - cutlass_scaled_mm = None - -try: - from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops - - if HAS_LIGHTLLM_KERNEL: - - def scaled_fp8_quant(tensor, *args, **kwargs): - return light_ops.per_token_quant_bf16_fp8(tensor) - - else: - if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant - else: - scaled_fp8_quant = None -except ImportError: - HAS_LIGHTLLM_KERNEL = False - if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant - else: - scaled_fp8_quant = None - - -@QUANTMETHODS.register(["fp8-per-token", "fp8w8a8"]) -class FP8PerTokenQuantization(QuantizationMethod): - def __init__(self): - super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager - self.is_moe = False - self.has_weight_scale = True - 
self.has_weight_zero_point = False - self._backend = QUANT_BACKEND.get_backend("fp8-per-token") - logger.info(f"FP8PerTokenQuantization using backend: {self._backend.name}") - - @property - def method_name(self): - return "fp8-per-token" - - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - """Quantize weights using per-token FP8 quantization.""" - if self.is_moe: - return self._quantize_moe(weight, output, offset) - - if scaled_fp8_quant is None: - raise RuntimeError("No FP8 quantization kernel available. Install vLLM or lightllm-kernel.") - - qweight, weight_scale = scaled_fp8_quant( - weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - output.weight[offset : offset + qweight.shape[0], :].copy_(qweight) - output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) - return - - def _quantize_moe(self, weight: torch.Tensor, output: WeightPack, offset: int) -> None: - if scaled_fp8_quant is None: - raise RuntimeError("No FP8 quantization kernel available. Install vLLM or lightllm-kernel.") - - num_experts = weight.shape[0] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) - weight_scales = [] - for i in range(num_experts): - qweight, weight_scale = scaled_fp8_quant( - weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.stack(weight_scales, dim=0).contiguous() - output.weight.copy_(qweights) - output.weight_scale.copy_(weight_scale) - return - - def apply( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - use_custom_tensor_mananger: bool = True, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if self._backend == BackendType.TRITON: - return self._apply_triton(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) - else: - return self._apply_vllm(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) - - def _apply_vllm( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - - m = input_tensor.shape[0] - n = qweight.shape[1] - - if out is None: - if use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) - return out - - def _apply_triton( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - - if scaled_fp8_quant is None: - raise RuntimeError("No FP8 quantization kernel available. 
Install vLLM or lightllm-kernel.") - - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - - m = input_tensor.shape[0] - n = qweight.shape[1] - - if out is None: - if use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) - - if bias is not None: - out.add_(bias) - return out - - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index e92d821c15..c05c90b210 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -5,7 +5,8 @@ from lightllm.common.quantization.registry import QUANTMETHODS -@QUANTMETHODS.register("none") +@QUANTMETHODS.register("none", platform="musa") +@QUANTMETHODS.register("none", platform="cuda") class NoQuantization(QuantizationMethod): """No quantization - uses full precision weights.""" diff --git a/lightllm/common/quantization/registry.py b/lightllm/common/quantization/registry.py index e9b4073987..c9baa64e27 100644 --- a/lightllm/common/quantization/registry.py +++ b/lightllm/common/quantization/registry.py @@ -5,21 +5,27 @@ class QuantMethodFactory: def __init__(self): self._quant_methods = {} - def register(self, names): + def register(self, names, platform="cuda"): def decorator(cls): local_names = names if isinstance(local_names, str): local_names = [local_names] for n in local_names: - self._quant_methods[n] = cls + if n not in self._quant_methods: + self._quant_methods[n] = {} + self._quant_methods[n][platform] = cls return cls return decorator - def get(self, key, *args, **kwargs) -> "QuantizationMethod": - quant_method_class = self._quant_methods.get(key) - if not quant_method_class: + def get(self, key, platform="cuda", *args, **kwargs) -> "QuantizationMethod": + quant_method_class_dict = self._quant_methods.get(key) + if not quant_method_class_dict: raise ValueError(f"QuantMethod '{key}' not supported.") + + quant_method_class = quant_method_class_dict.get(platform) + if quant_method_class is None: + raise ValueError(f"QuantMethod '{key}' for platform '{platform}' not supported.") return quant_method_class() diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 721807356a..0a74d9887c 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -1,74 +1,72 @@ +import os import torch +import torch.nn.functional as F from typing import Optional +from .quantize_method import QuantizationMethod +from .registry import QUANTMETHODS +from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import fp8_scaled_mm_per_token +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_gemm_kernel import 
w8a8_block_fp8_matmul +from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops -from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack -from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.quantization.scaled_mm_per_token_kernel import ( - fp8_scaled_mm_per_token, - int8_scaled_mm_per_token, -) -from lightllm.utils.log_utils import init_logger -logger = init_logger(__name__) +from .quantize_method import WeightPack -# Conditional imports for optional backends -try: - from lightllm.utils.vllm_utils import HAS_VLLM +if HAS_LIGHTLLM_KERNEL: - if HAS_VLLM: - from lightllm.utils.vllm_utils import vllm_ops, cutlass_scaled_mm - else: - vllm_ops = None - cutlass_scaled_mm = None -except ImportError: - HAS_VLLM = False - vllm_ops = None - cutlass_scaled_mm = None - - -try: - from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops - - if HAS_LIGHTLLM_KERNEL: - - def scaled_fp8_quant(tensor, *args, **kwargs): - return light_ops.per_token_quant_bf16_fp8(tensor) + def scaled_fp8_quant(tensor, *args, **kwargs): + return light_ops.per_token_quant_bf16_fp8(tensor) - else: - if HAS_VLLM: - scaled_fp8_quant = vllm_ops.scaled_fp8_quant - else: - scaled_fp8_quant = None -except ImportError: - HAS_LIGHTLLM_KERNEL = False +else: if HAS_VLLM: scaled_fp8_quant = vllm_ops.scaled_fp8_quant - else: - scaled_fp8_quant = None + +LIGHTLLM_USE_TRITON_FP8_SCALED_MM = os.getenv("LIGHTLLM_USE_TRITON_FP8_SCALED_MM", "False").upper() in [ + "ON", + "TRUE", + "1", +] -@QUANTMETHODS.register(["w8a8", "vllm-w8a8"]) -class W8A8Quantization(QuantizationMethod, PlatformAwareOp): +class BaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() + assert HAS_VLLM, "vllm are not installed, you can't use quant api of them." 
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager self.cache_manager = g_cache_manager - self.has_weight_scale = True - self.has_weight_zero_point = False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + raise NotImplementedError("Not implemented") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError("Not implemented") @property def method_name(self): - return "w8a8" + return "w8a8-base" def create_weight( self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) + raise NotImplementedError("Not implemented") + + +@QUANTMETHODS.register(["vllm-w8a8", "w8a8"], platform="cuda") +class w8a8QuantizationMethod(BaseQuantizationMethod): + def __init__(self): + super().__init__() + self.has_weight_scale = True + self.has_weight_zero_point = False def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: weight = weight.float().cuda(self.device_id_) @@ -88,90 +86,45 @@ def apply( use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self._forward( - input_tensor=input_tensor, - weight_pack=weight_pack, - out=out, - workspace=workspace, - use_custom_tensor_mananger=use_custom_tensor_mananger, - bias=bias, - ) - - def _triton_forward( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - workspace: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - + input_scale = None qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - - # TODO: support fp8 quantization triton - - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - + input_scale = None # dynamic quantization for input tensor + x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True) m = input_tensor.shape[0] n = qweight.shape[1] - if out is None: if use_custom_tensor_mananger: out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) else: out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - out = int8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) - - if bias is not None: - out.add_(bias) + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out - def _cuda_forward( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - workspace: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - - x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True) - - m = input_tensor.shape[0] - n = qweight.shape[1] - - if out is None: - if 
use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + @property + def method_name(self): + return "vllm-w8a8" - cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) - return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) -class Fp8W8A8Quantization(QuantizationMethod, PlatformAwareOp): +@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"], platform="cuda") +class FP8w8a8QuantizationMethod(BaseQuantizationMethod): def __init__(self): super().__init__() - from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager - - self.cache_manager = g_cache_manager self.is_moe = False self.has_weight_scale = True self.has_weight_zero_point = False - @property - def method_name(self): - return "f8w8a8" - def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: - """Quantize weights using per-token FP8 quantization.""" + if self.is_moe: + return self.quantize_moe(weight, output, offset) qweight, weight_scale = scaled_fp8_quant( weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True ) @@ -179,18 +132,18 @@ def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> output.weight_scale[offset : offset + weight_scale.shape[0]].copy_(weight_scale.view(-1)) return - def create_weight( - self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 - ) -> WeightPack: - expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - if self.is_moe: - assert num_experts > 1, "Number of experts must be greater than 1 for MOE" - # per-tensor weight quantization for moe - weight_scale = torch.empty((num_experts,), dtype=torch.float32).cuda(device_id) - else: - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) - return WeightPack(weight=weight, weight_scale=weight_scale) + def quantize_moe(self, weight: torch.Tensor) -> WeightPack: + num_experts = weight.shape[0] + qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_) + weight_scales = [] + for i in range(num_experts): + qweight, weight_scale = scaled_fp8_quant( + weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + ) + qweights[i] = qweight + weight_scales.append(weight_scale) + weight_scale = torch.stack(weight_scales, dim=0).contiguous() + return WeightPack(weight=qweights, weight_scale=weight_scale) def apply( self, @@ -200,60 +153,100 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self._forward(input_tensor, weight_pack, out, use_custom_tensor_mananger, bias) - - def _cuda_forward( - self, - input_tensor: torch.Tensor, - weight_pack: WeightPack, - out: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], ) -> torch.Tensor: 
qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - m = input_tensor.shape[0] n = qweight.shape[1] - if out is None: if use_custom_tensor_mananger: out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) else: out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - - cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) + if LIGHTLLM_USE_TRITON_FP8_SCALED_MM: + out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) + else: + cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias) return out - def _apply_triton( + @property + def method_name(self): + return "vllm-fp8w8a8" + + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) + + +@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"], platform="cuda") +class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod): + def __init__(self): + super().__init__() + self.block_size = 128 + self.weight_scale_suffix = "weight_scale_inv" + self.has_weight_scale = True + self.has_weight_zero_point = False + + def quantize(self, weight: torch.Tensor, output: WeightPack, offset: int = 0) -> None: + from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant + + device = output.weight.device + weight, scale = weight_quant(weight.cuda(device), self.block_size) + output.weight[offset : offset + weight.shape[0], :].copy_(weight) + output.weight_scale[offset // self.block_size : offset + weight.shape[0] // self.block_size].copy_(scale) + return + + def apply( self, input_tensor: torch.Tensor, weight_pack: WeightPack, - out: Optional[torch.Tensor], - use_custom_tensor_mananger: bool, - bias: Optional[torch.Tensor], + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = weight_pack.weight.t() - weight_scale = weight_pack.weight_scale - - # TODO: support fp8 quantization triton - - x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) - - m = input_tensor.shape[0] + weight_scale = weight_pack.weight_scale.t() + input_scale = None # dynamic quantization for input tensor + m, k = input_tensor.shape n = qweight.shape[1] - + alloc_func = torch.empty if not use_custom_tensor_mananger else self.cache_manager.empty + if input_scale is None: + qinput_tensor, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, dtype=qweight.dtype, alloc_func=alloc_func + ) if out is None: - if use_custom_tensor_mananger: - out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device) - else: - out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + out = alloc_func((m, n), dtype=input_tensor.dtype, device=input_tensor.device) + if n % 128 != 0: + w8a8_block_fp8_matmul( + qinput_tensor, + qweight, + input_scale, + weight_scale, + out, + 
(self.block_size, self.block_size), + dtype=input_tensor.dtype, + ) + else: + input_scale = input_scale.t().contiguous().t() + cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias) + return out - out = fp8_scaled_mm_per_token(x_q, qweight, x_scale, weight_scale, input_tensor.dtype, out) + @property + def method_name(self): + return "vllm-fp8w8a8-b128" - if bias is not None: - out.add_(bias) - return out + def create_weight( + self, out_dim: int, in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + ) -> WeightPack: + expert_prefix = (num_experts,) if num_experts > 1 else () + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight_scale = torch.empty( + expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 + ).cuda(device_id) + return WeightPack(weight=weight, weight_scale=weight_scale) diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index 2e14eca26f..cb540ee4d5 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -15,23 +15,13 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - # Share weight with EmbeddingWeight to save memory - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.embed_tokens.weight", - data_type=self.data_type_, - shared_weight=self.wte_weight_, - ) - else: - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="lm_head.weight", - data_type=self.data_type_, - ) - + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, weight_name="model.norm.weight", diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index d273d51ad5..52f9289eb1 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -91,15 +91,13 @@ def _tpsp_get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) - cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - - layer_weight.q_norm_weight_(q.view(-1, self.head_dim_), eps=self.eps_, out=q.view(-1, self.head_dim_)) - - cache_kv[:, : self.tp_k_head_num_, :] = layer_weight.k_norm_weight_( - cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), + cache_kv = layer_weight.kv_proj.mm(input) + layer_weight.q_norm_weight_(q, eps=self.eps_) + layer_weight.k_norm_weight_( + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, - alloc_func=self.alloc_tensor, - ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + ) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py 
b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index 475bcee95b..b3fe64ce59 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -14,22 +14,13 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - # Share weight with EmbeddingWeight to save memory - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.language_model.embed_tokens.weight", - data_type=self.data_type_, - shared_weight=self.wte_weight_, - ) - else: - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) self.final_norm_weight_ = RMSNormWeight( dim=hidden_size, weight_name="model.language_model.norm.weight", diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index e6d5cb441d..2071b52cd5 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -14,22 +14,13 @@ def __init__(self, data_type, network_config): data_type=self.data_type_, ) tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) - if tie_word_embeddings: - # Share weight with EmbeddingWeight to save memory - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="model.embed_tokens.weight", - data_type=self.data_type_, - shared_weight=self.wte_weight_, - ) - else: - self.lm_head_weight_ = LMHeadWeight( - dim=hidden_size, - vocab_size=vocab_size, - weight_name="lm_head.weight", - data_type=self.data_type_, - ) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) self.final_norm_weight_ = LayerNormWeight( dim=hidden_size, diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 0bf974df1f..f30d4cdf77 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -318,7 +318,7 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, nargs="+", choices=["None", "triton", "fa3", "flashinfer"], - default=["triton"], + default=["fa3"], help="""prefill attention kernel used in llm. None: automatically select backend based on current GPU device, not supported yet, will support in future""", @@ -328,7 +328,7 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, nargs="+", choices=["None", "triton", "fa3", "flashinfer"], - default=["triton"], + default=["fa3"], help="""decode attention kernel used in llm. 
None: automatically select backend based on current GPU device, not supported yet, will support in future""", From d5ac1a835b0a0bc352788974c64c85a553ebb129 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 19 Jan 2026 14:00:04 +0000 Subject: [PATCH 29/43] update docs --- docs/CN/source/models/add_new_model.md | 37 -------------------------- docs/EN/source/models/add_new_model.md | 36 ------------------------- 2 files changed, 73 deletions(-) diff --git a/docs/CN/source/models/add_new_model.md b/docs/CN/source/models/add_new_model.md index 49b47ffa26..5d34cf747e 100755 --- a/docs/CN/source/models/add_new_model.md +++ b/docs/CN/source/models/add_new_model.md @@ -162,19 +162,6 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): self.tp_rank_: split_vob_size * (self.tp_rank_ + 1), :]) self.lm_head_weight_ = self.wte_weight_ return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return - ~~~ ***transformer_layer_weight.py*** @@ -204,30 +191,6 @@ class BloomTransformerLayerWeight(TransformerLayerWeight): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return def _load_qkvo_weights(self, weights): if f"h.{self.layer_num_}.input_layernorm.weight" in weights: diff --git a/docs/EN/source/models/add_new_model.md b/docs/EN/source/models/add_new_model.md index 6127dffaf7..7417c39cfe 100755 --- a/docs/EN/source/models/add_new_model.md +++ b/docs/EN/source/models/add_new_model.md @@ -162,18 +162,6 @@ class BloomPreAndPostLayerWeight(PreAndPostLayerWeight): self.tp_rank_: split_vob_size * (self.tp_rank_ + 1), :]) self.lm_head_weight_ = self.wte_weight_ return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.pre_norm_weight_, - self.pre_norm_bias_, - self.final_norm_weight_, - self.final_norm_bias_, - self.wte_weight_, - self.lm_head_weight_] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return ~~~ @@ -204,30 +192,6 @@ class BloomTransformerLayerWeight(TransformerLayerWeight): self._load_qkvo_weights(weights) self._load_ffn_weights(weights) return - - def verify_load(self): - errors = "weights load not ok" - weights = [self.att_norm_weight_, - self.att_norm_bias_, - self.q_weight_, - self.k_weight_, - self.v_weight_, - self.q_bias_, - self.k_bias_, - self.v_bias_, - self.o_weight_, - self.o_bias_, - - self.ffn_norm_weight_, - self.ffn_norm_bias_, - self.ffn_1_weight_, - self.ffn_1_bias_, - self.ffn_2_weight_, - self.ffn_2_bias_, - ] - for i in range(len(weights)): - assert weights[i] is not None, "index:" + str(i) + " " + errors - return def _load_qkvo_weights(self, weights): if f"h.{self.layer_num_}.input_layernorm.weight" in weights: From c1274eaafdc4ec41864e1055f86a5c6cd40049df Mon Sep 17 00:00:00 2001 From: shihaobai 
<1798930569@qq.com> Date: Mon, 19 Jan 2026 14:39:50 +0000 Subject: [PATCH 30/43] fix pre-weight --- .../models/llama/layer_weights/pre_and_post_layer_weight.py | 2 +- .../qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py | 2 +- .../starcoder2/layer_weights/pre_and_post_layer_weight.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index cb540ee4d5..8efa36cf80 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -18,7 +18,7 @@ def __init__(self, data_type, network_config): self.lm_head_weight_ = LMHeadWeight( dim=hidden_size, vocab_size=vocab_size, - weight_name="model.embed_tokens.weight", + weight_name="lm_head.weight", data_type=self.data_type_, embedding_weight=self.wte_weight_ if tie_word_embeddings else None, ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py index b3fe64ce59..0ba06c6aed 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/pre_and_post_layer_weight.py @@ -17,7 +17,7 @@ def __init__(self, data_type, network_config): self.lm_head_weight_ = LMHeadWeight( dim=hidden_size, vocab_size=vocab_size, - weight_name="model.language_model.embed_tokens.weight", + weight_name="lm_head.weight", data_type=self.data_type_, embedding_weight=self.wte_weight_ if tie_word_embeddings else None, ) diff --git a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py index 2071b52cd5..cc256c442d 100644 --- a/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py @@ -17,7 +17,7 @@ def __init__(self, data_type, network_config): self.lm_head_weight_ = LMHeadWeight( dim=hidden_size, vocab_size=vocab_size, - weight_name="model.embed_tokens.weight", + weight_name="lm_head.weight", data_type=self.data_type_, embedding_weight=self.wte_weight_ if tie_word_embeddings else None, ) From 1bd148d17bbe552ef0d0e9926f817599aa221f60 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Tue, 20 Jan 2026 13:58:34 +0800 Subject: [PATCH 31/43] fix cpu kv cache offload async error (#1180) --- .../mode_backend/multi_level_kv_cache.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index d4ba902999..b9d13f512b 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -26,11 +26,16 @@ def __init__(self, backend): self.init_sync_group = create_new_group_for_current_dp("nccl") dist.barrier(group=self.init_sync_group) + self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") + self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") + self.cpu_cache_handle_queue: Deque[TransTask] = deque() self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) # 一些算子模式需要同步计算和 cpu cache 的 load 和 
offload 操作 - self.need_sync_compute_stream: bool = True + self.need_sync_compute_stream: bool = ( + "fa3" in self.args.llm_decode_att_backend or "fa3" in self.args.llm_prefill_att_backend + ) def wait(self): """ @@ -89,14 +94,18 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): cpu_kv_cache_scale = None gpu_kv_cache_scale = None + mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) + page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda( + non_blocking=True + ) # 将 cpu page 的内容拷贝到 gpu 页面中 load_cpu_kv_to_gpu( - gpu_mem_indexes=mem_indexes.cuda(non_blocking=True), + gpu_mem_indexes=mem_indexes_cuda, gpu_kv_cache=mem_manager.kv_buffer, gpu_kv_cache_scale=gpu_kv_cache_scale, cpu_kv_cache=cpu_kv_cache, cpu_kv_cache_scale=cpu_kv_cache_scale, - page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True), + page_indexes=page_indexes_cuda, tp_index=self.backend.rank_in_dp, tp_world_size=self.backend.dp_world_size, grid_num=grid_num, @@ -221,6 +230,12 @@ def _start_kv_cache_offload_task( page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True) page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True) + assert len(page_indexes) <= self.page_index_buffer.shape[0] + cuda_page_indexes = self.page_index_buffer[: len(page_indexes)] + cuda_page_readies = self.page_ready_buffer[: len(page_readies)] + cuda_page_indexes.copy_(page_indexes, non_blocking=True) + cuda_page_readies.copy_(page_readies, non_blocking=True) + move_token_num = item_size * self.args.cpu_cache_token_page_size assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num] @@ -248,8 +263,8 @@ def _start_kv_cache_offload_task( gpu_kv_cache_scale=gpu_kv_cache_scale, cpu_kv_cache=cpu_kv_cache, cpu_kv_cache_scale=cpu_kv_cache_scale, - page_indexes=page_indexes, - page_readies=page_readies, + page_indexes=cuda_page_indexes, + page_readies=cuda_page_readies, tp_index=self.backend.rank_in_dp, tp_world_size=self.backend.dp_world_size, grid_num=grid_num, From 148e7f15219182b60bc48d63b63960d6d0788803 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 20 Jan 2026 07:28:19 +0000 Subject: [PATCH 32/43] fix deepseek --- lightllm/common/basemodel/basemodel.py | 2 + .../layer_weights/base_layer_weight.py | 9 ++ .../layer_weights/meta_weights/__init__.py | 1 + .../meta_weights/att_sink_weight.py | 1 + .../layer_weights/meta_weights/base_weight.py | 7 ++ .../meta_weights/embedding_weight.py | 15 +++ .../fused_moe/fused_moe_weight_tp.py | 27 +++-- .../meta_weights/mm_weight/__init__.py | 2 +- .../meta_weights/mm_weight/mm_weight.py | 71 +++++++++++++ .../meta_weights/mm_weight/rowmm_weight.py | 39 +++++++- .../layer_weights/meta_weights/norm_weight.py | 19 +++- lightllm/common/quantization/__init__.py | 2 +- lightllm/common/quantization/deepgemm.py | 18 +--- .../layer_weights/transformer_layer_weight.py | 99 +++++++------------ .../layer_weights/transformer_layer_weight.py | 9 -- 15 files changed, 223 insertions(+), 98 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 435f39a88b..2dcf0c434a 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -184,6 +184,8 @@ def _load_hf_weights(self): transformer_layer_list=self.trans_layers_weight, weight_dict=self.weight_dict, ) + 
self.pre_post_weight.verify_load() + [weight.verify_load() for weight in self.trans_layers_weight] return def _init_mem_manager(self): diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index 6bdeb64d20..1875e2c3b3 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -26,5 +26,14 @@ def init_static_params(self): """ pass + def verify_load(self): + """ + verify all load is ok + """ + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, BaseWeight): + assert attr.verify_load(), f"Loading {attr_name} of layers {self.layer_num_} fails." + def _cuda(self, cpu_tensor): return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index fef70acf50..b67f271ca4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -3,6 +3,7 @@ MMWeightTpl, ROWMMWeight, KVROWNMMWeight, + ROWBMMWeight, COLMMWeight, ) from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 1c22bcb7d9..32d59e66e7 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -10,6 +10,7 @@ def __init__(self, weight_name: str, data_type): self.weight_name = weight_name self.data_type_ = data_type self.weight: torch.Tensor = None + # TODO: add create weight function def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights or self.weight is not None: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 58860ab30e..da03887862 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -17,6 +17,10 @@ def load_hf_weights(self, weights): def _create_weight(self): pass + @abstractmethod + def verify_load(self): + pass + class BaseWeightTpl(BaseWeight): def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: torch.dtype = None): @@ -29,5 +33,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to def load_hf_weights(self, weights): raise NotImplementedError("load_hf_weights must implement this method") + def verify_load(self): + raise NotImplementedError("verify_load must implement this method") + def _create_weight(self) -> bool: raise NotImplementedError("create_weight must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index df9050d4fc..9737f41b29 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -28,6 +28,7 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch def _create_weight(self): 
tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + self.load_cnt = 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights: @@ -40,6 +41,10 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == 1 def _native_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty @@ -109,6 +114,7 @@ def __init__( self._create_weight() def _create_weight(self): + self.load_cnt = 0 if self._embedding_weight is not None: self.weight = self._embedding_weight.weight return @@ -128,6 +134,10 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" logger.info(f"loaded weight vocab_size: {self.vocab_size}") self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == 1 or self._embedding_weight is not None def _native_forward( self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty @@ -171,6 +181,7 @@ def _create_weight(self): self.weight: torch.Tensor = torch.empty( self.max_position_embeddings, self.dim, dtype=self.data_type_, device=self.device_id_ ) + self.load_cnt = 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights: @@ -182,6 +193,10 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" logger.info(f"loaded weight max_position_embeddings: {self.max_position_embeddings}") self.weight.copy_(t_weight.to(self.data_type_)) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == 1 def _native_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py index 876dc44bd2..c7892ab3ba 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py @@ -8,6 +8,7 @@ get_row_slice_mixin, get_col_slice_mixin, ) +import threading def create_tp_moe_wegiht_obj( @@ -100,6 +101,7 @@ def __init__( self.col_slicer = get_col_slice_mixin( self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ ) + self.lock = threading.Lock() self._create_weight() def _create_weight(self): @@ -107,7 +109,7 @@ def _create_weight(self): intermediate_size = self.split_inter_size # Create e_score_correction_bias - if self.e_score_correction_bias is not None: + if self.e_score_correction_bias_name is not None: self.e_score_correction_bias = torch.empty( (total_expert_num,), dtype=self.data_type_, @@ -128,6 +130,7 @@ def _create_weight(self): device_id=self.device_id_, num_experts=total_expert_num, ) + self.load_cnt = 0 
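The `load_cnt` / `verify_load()` additions in this patch all follow the same pattern: every tensor that is actually copied inside `load_hf_weights` bumps a counter, and once all checkpoint shards have been consumed the counter is compared against the number of tensors the module expects, so a missing or misnamed weight fails loudly instead of leaving uninitialized memory behind. A minimal, self-contained sketch of that idea (illustrative only; `CountingWeight` is a made-up class, not a LightLLM type):

import torch
from typing import Dict, List


class CountingWeight:
    """Toy weight holder: count successful copies, verify the count afterwards."""

    def __init__(self, weight_names: List[str], dim: int, dtype=torch.float32):
        self.weight_names = list(weight_names)
        self.weight = torch.empty(len(self.weight_names), dim, dtype=dtype)
        self.load_cnt = 0  # incremented once per tensor actually copied in

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        # Checkpoints are sharded, so each call may carry only some of the names.
        for i, name in enumerate(self.weight_names):
            if name in weights:
                self.weight[i].copy_(weights[name].to(self.weight.dtype))
                self.load_cnt += 1

    def verify_load(self) -> bool:
        # Every expected tensor must have been seen across the shards.
        return self.load_cnt == len(self.weight_names)


if __name__ == "__main__":
    w = CountingWeight(["a.weight", "b.weight"], dim=4)
    w.load_hf_weights({"a.weight": torch.ones(4)})   # first shard
    assert not w.verify_load()                       # "b.weight" still missing
    w.load_hf_weights({"b.weight": torch.zeros(4)})  # second shard
    assert w.verify_load()                           # all tensors accounted for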
def _select_experts( self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group @@ -254,13 +257,18 @@ def load_hf_weights(self, weights): # Load each expert with TP slicing for i_experts in range(self.n_routed_experts): - self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) + with self.lock: + self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) if self.w13.weight_scale is not None: - self._load_expert(i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix) + with self.lock: + self._load_expert( + i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix + ) if self.w13.weight_zero_point is not None: - self._load_expert( - i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix - ) + with self.lock: + self._load_expert( + i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix + ) def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): if self.quant_method.weight_need_quanted(weight): @@ -276,12 +284,17 @@ def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) if w1_weight in weights: load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) + self.load_cnt += 1 if w3_weight in weights: load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) - + self.load_cnt += 1 load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) if w2_weight in weights: load_func(slice_func(weights[w2_weight]), self.w2.get_expert(expert_idx), start_idx=0) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == self.n_routed_experts * 3 * 2 def _get_load_and_slice_func(self, type: str, is_row: bool = True): if is_row: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py index ae0c651977..e9ae4f30ab 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py @@ -1,5 +1,5 @@ from .mm_weight import ( MMWeightTpl, ) -from .rowmm_weight import ROWMMWeight, KVROWNMMWeight +from .rowmm_weight import ROWMMWeight, KVROWNMMWeight, ROWBMMWeight from .colmm_weight import COLMMWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 728ed82fa9..3ba4d3e592 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -127,6 +127,7 @@ def _create_weight(self): self.mm_param: WeightPack = self.quant_method.create_weight( in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() ) + self.load_cnt = 0 return # 执行顺序 @@ -140,6 +141,7 @@ def _load_weight( self.quant_method.quantize(weight, self.mm_param, offset=start_idx) else: self.quant_method.load_weight(weight, self.mm_param, start_idx) + self.load_cnt += 1 return def _load_bias( @@ -159,6 +161,7 @@ def _load_weight_scale( weight_scale = 
self.param_slicer._slice_weight_scale(weights[param_name]) start_idx = self.cusum_out_dims[sub_child_index] self.quant_method.load_weight_scale(weight_scale, self.mm_param, start_idx) + self.load_cnt += 1 return def _load_weight_zero_point( @@ -168,10 +171,78 @@ def _load_weight_zero_point( weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) start_idx = self.cusum_out_dims[sub_child_index] self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) + self.load_cnt += 1 return + def verify_load(self): + if self.quant_method.method_name != "none": + return self.load_cnt == len(self.weight_names) * 2 + else: + return self.load_cnt == len(self.weight_names) + def _get_tp_dim(self, dim: int) -> int: assert ( dim % self.tp_world_size_ == 0 ), f"dim must be divisible by tp_world_size_, but found: {dim} % {self.tp_world_size_}" return dim // self.tp_world_size_ + + +class BMMWeightTpl(BaseWeightTpl): + def __init__( + self, + dim0: int, + dim1: int, + dim2: int, + weight_names: Union[str, List[str]], + data_type: torch.dtype, + bias_names: Optional[Union[str, List[str]]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + super().__init__(tp_rank, tp_world_size, data_type) + if isinstance(weight_names, str): + weight_names = [weight_names] + self.weight_names = weight_names + self.bias_names = bias_names + assert bias_names is None, "bmm not support bias" + if isinstance(bias_names, list): + assert all(bias_name is None for bias_name in bias_names), "bmm not support bias" + assert quant_method is None, "bmm not support quantized weight" + self.quant_method = quant_method + self.dim0 = dim0 + self.dim1 = dim1 + self.dim2 = dim2 + self._create_weight() + return + + def _create_weight(self): + self.weight = torch.empty(self.dim0, self.dim1, self.dim2, dtype=self.data_type_).cuda(get_current_device_id()) + self.load_cnt = 0 + return + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + for weight_name in self.weight_names: + if weight_name in weights: + weight = self.param_slicer._slice_weight(weights[weight_name]) + self.weight.copy_(weight) + self.load_cnt += 1 + return + + def verify_load(self): + return self.load_cnt == len(self.weight_names) + + def bmm( + self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True + ) -> torch.Tensor: + # 目前 bmm 不支持量化运算操作 + fpweight = self.weight + if out is None: + shape = (input_tensor.shape[0], input_tensor.shape[1], fpweight.shape[2]) + dtype = input_tensor.dtype + device = input_tensor.device + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, dtype, device=device) + else: + out = torch.empty(shape, dtype=dtype, device=device) + return torch.bmm(input_tensor, fpweight, out=out) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index d7554b3757..30a699bb68 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -1,7 +1,5 @@ import torch -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import ( - MMWeightTpl, -) +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import MMWeightTpl, BMMWeightTpl from lightllm.common.quantization import Quantcfg from 
lightllm.utils.dist_utils import get_current_device_id from lightllm.common.quantization.quantize_method import QuantizationMethod @@ -92,3 +90,38 @@ def _get_tp_padded_head_num(self, head_num: int): f"tp_world_size_ must be divisible by head_num, " f"but found: {head_num} % {self.tp_world_size_}" ) + + +class ROWBMMWeight(BMMWeightTpl): + def __init__( + self, + dim0: int, + dim1: int, + dim2: int, + weight_names: Union[str, List[str]], + data_type: torch.dtype, + bias_names: Optional[Union[str, List[str]]] = None, + quant_method: QuantizationMethod = None, + tp_rank: int = None, + tp_world_size: int = None, + ) -> None: + self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() + self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + assert ( + dim0 % self.tp_world_size_ == 0 + ), f"dim0 of bmm must be divisible by tp_world_size_, but found: {dim0} % {self.tp_world_size_}" + dim0 = dim0 // self.tp_world_size_ + super().__init__( + dim0=dim0, + dim1=dim1, + dim2=dim2, + weight_names=weight_names, + bias_names=bias_names, + data_type=data_type, + quant_method=quant_method, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + ) + self.param_slicer = get_row_slice_mixin( + quant_method_name="none", tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index d7bbe5567a..1a8f59723b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -19,10 +19,15 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.load_cnt = 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == 1 def _native_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty @@ -42,7 +47,9 @@ def _native_forward( def _triton_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - assert input.ndim == 2 and self.weight.ndim == 1 + assert ( + input.ndim in [2, 3] and self.weight.ndim == 1 + ), f"input.ndim: {input.ndim} != 2 or weight.ndim: {self.weight.ndim} != 1" if out is None: out = alloc_func(input.shape, dtype=input.dtype, device=input.device) return rmsnorm_forward(x=input, weight=self.weight, eps=eps, out=out) @@ -77,12 +84,18 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) self.bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.load_cnt = 0 def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) + self.load_cnt += 1 if self.bias_name in weights: self.bias.copy_(weights[self.bias_name]) + self.load_cnt += 1 + + def verify_load(self): + return self.load_cnt == 2 def _native_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty @@ -162,6 
+175,7 @@ def load_hf_weights(self, weights): self.weight[:, end - start].copy_(t_weight[start:end].to(self.data_type_)) # the padding part is zero self.weight[:, end:].zero_() + self.load_cnt += 1 class NoTpGEMMANormWeight(RMSNormWeight): @@ -173,7 +187,8 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) - self.weight += 1 + self.weight += 1 + self.load_cnt += 1 class QKRMSNORMWeight(RMSNormWeight): diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index bf99622ef2..1e47454490 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -36,7 +36,7 @@ def _mapping_quant_method(self): if self.hf_quantization_method == "fp8": block_size = self.hf_quantization_config.get("weight_block_size", None) if block_size == [128, 128]: - from lightllm.common.quantization.deepgemm_quant import HAS_DEEPGEMM + from lightllm.common.quantization.deepgemm import HAS_DEEPGEMM if HAS_DEEPGEMM: self.quant_type = "deepgemm-fp8w8a8-b128" diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index c9f227120d..80be14c335 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -48,7 +48,7 @@ class DeepGEMMFP8w8a8B128QuantizationMethod(DeepGEMMBaseQuantizationMethod): def __init__(self): super().__init__() self.block_size = 128 - self.weight_suffix = None + self.weight_suffix = "weight" self.weight_zero_point_suffix = None self.weight_scale_suffix = "weight_scale_inv" self.has_weight_scale = True @@ -102,9 +102,9 @@ def create_weight( ) -> WeightPack: expert_prefix = (num_experts,) if num_experts > 1 else () weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty( - expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) + scale_out_dim = (out_dim + self.block_size - 1) // self.block_size + scale_in_dim = (in_dim + self.block_size - 1) // self.block_size + weight_scale = torch.empty(expert_prefix + (scale_out_dim, scale_in_dim), dtype=torch.float32).cuda(device_id) return WeightPack(weight=weight, weight_scale=weight_scale) def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: @@ -112,15 +112,7 @@ def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: return def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_scale[ - start_idx // self.block_size : start_idx + weight_scale.shape[0] // self.block_size - ].copy_(weight_scale) - return - - def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack, start_idx: int) -> None: - weight_pack.weight_zero_point[ - start_idx // self.block_size : start_idx + weight_zero_point.shape[0] // self.block_size - ].copy_(weight_zero_point) + weight_pack.weight_scale[start_idx // self.block_size : start_idx + weight_scale.shape[0]].copy_(weight_scale) return diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index d927f22d1b..1e8d572e15 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ 
b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -6,6 +6,7 @@ from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, + ROWBMMWeight, COLMMWeight, RMSNormWeight, FusedMoeWeightEP, @@ -65,31 +66,14 @@ def _init_weight(self): self._init_ffn() self._init_norm() - def _load_kb(self, kv_b_proj_): - k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[ - :, : self.qk_nope_head_dim, : - ] - return k_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_kb_scale(self, kv_b_proj_, block_size): - k_b_proj_scale_ = kv_b_proj_.view( - self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size - )[:, : self.qk_nope_head_dim // block_size, :] - return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb(self, kv_b_proj_): - v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[ - :, :, self.qk_nope_head_dim : - ].transpose(0, 1) - return v_b_proj_.contiguous().to(kv_b_proj_.dtype) - - def _load_vb_scale(self, kv_b_proj_scale_, block_size): - v_b_proj_scale_ = kv_b_proj_scale_.T.view( - self.kv_lora_rank // block_size, - self.num_attention_heads, - self.qk_nope_head_dim * 2 // block_size, - )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) - return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + def _split_kv_b_proj(self, kv_b_proj_): + kv_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank) + k_b_proj_, v_b_proj_ = torch.split(kv_b_proj_, [self.qk_nope_head_dim, self.v_head_dim], dim=-2) + # num_attention_heads x qk_nope_head_dim x kv_lora_rank + k_b_proj_ = k_b_proj_.contiguous().to(kv_b_proj_.dtype) + # num_attention_heads x kv_lora_rank x v_head_dim + v_b_proj_ = v_b_proj_.transpose(1, 2).contiguous().to(kv_b_proj_.dtype) + return k_b_proj_, v_b_proj_ def _rename_shared_experts(self, weights, weight_scale_suffix): # 将共享专家对应的参数,改造为与路由专家一致的权重名称和映射关系。 @@ -122,21 +106,9 @@ def load_hf_weights(self, weights): kv_b_proj_.cuda(), weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), ).cpu() - weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) - weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) - - if ( - self.quant_cfg.quantized_weight - and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix in weights - ): - kv_b_proj_scale_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix] - block_size = 128 - weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + weight_scale_suffix] = self._load_kb_scale( - kv_b_proj_scale_, block_size - ) - weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj." 
+ weight_scale_suffix] = self._load_vb_scale( - kv_b_proj_scale_, block_size - ) + k_b_proj_, v_b_proj_ = self._split_kv_b_proj(kv_b_proj_) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = v_b_proj_ # rename the shared experts weight if self.num_fused_shared_experts > 0: @@ -181,20 +153,22 @@ def _init_qkvo(self): data_type=self.data_type_, quant_method=self.get_quant_method("q_b_proj"), ) - # self.k_b_proj_ = ROWBMMWeight( - # weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", - # data_type=self.data_type_, - # quant_cfg=None, - # layer_num=self.layer_num_, - # name="k_b_proj", - # ) - # self.v_b_proj_ = ROWBMMWeight( - # weight_names=f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", - # data_type=self.data_type_, - # quant_cfg=None, - # layer_num=self.layer_num_, - # name="v_b_proj", - # ) + self.k_b_proj_ = ROWBMMWeight( + dim0=self.num_attention_heads, + dim1=self.qk_nope_head_dim, + dim2=self.kv_lora_rank, + weight_names=f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight", + data_type=self.data_type_, + quant_method=None, + ) + self.v_b_proj_ = ROWBMMWeight( + dim0=self.num_attention_heads, + dim1=self.kv_lora_rank, + dim2=self.v_head_dim, + weight_names=f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight", + data_type=self.data_type_, + quant_method=None, + ) if self.enable_cc_method: self.cc_kv_b_proj_ = ROWMMWeight( in_dim=self.kv_lora_rank, @@ -212,12 +186,13 @@ def _init_qkvo(self): quant_method=self.get_quant_method("o_weight"), ) - def _load_mlp(self, mlp_prefix): + def _load_mlp(self, mlp_prefix, is_shared_experts=False): moe_mode = os.getenv("MOE_MODE", "TP") + mlp_inter = self.moe_inter if is_shared_experts else self.n_inter if self.is_moe and moe_mode == "EP": self.gate_up_proj = ROWMMWeight( in_dim=self.n_embed, - out_dims=[self.moe_inter, self.moe_inter], + out_dims=[mlp_inter, mlp_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, quant_method=self.get_quant_method("gate_up_proj"), @@ -225,7 +200,7 @@ def _load_mlp(self, mlp_prefix): tp_world_size=1, ) self.down_proj = COLMMWeight( - in_dim=self.moe_inter, + in_dim=mlp_inter, out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, @@ -236,13 +211,13 @@ def _load_mlp(self, mlp_prefix): else: self.gate_up_proj = ROWMMWeight( in_dim=self.n_embed, - out_dims=[self.n_inter, self.n_inter], + out_dims=[mlp_inter, mlp_inter], weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( - in_dim=self.n_inter, + in_dim=mlp_inter, out_dims=[self.n_embed], weight_names=f"{mlp_prefix}.down_proj.weight", data_type=self.data_type_, @@ -256,7 +231,7 @@ def _init_moe(self): out_dims=[self.n_routed_experts], weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, - quant_method=self.get_quant_method("moe_gate"), + quant_method=None, tp_rank=0, tp_world_size=1, ) @@ -267,7 +242,7 @@ def _init_moe(self): # 专家对应的 gate_up_proj 等weight 参数。当 num_fused_shared_experts # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: - self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts") + self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) moe_mode = 
os.getenv("MOE_MODE", "TP") assert moe_mode in ["EP", "TP"] if moe_mode == "TP": @@ -318,7 +293,7 @@ def _init_norm(self): data_type=self.data_type_, ) self.kv_a_layernorm_ = RMSNormWeight( - dim=self.kv_lora_rank + self.qk_rope_head_dim, + dim=self.kv_lora_rank, weight_name=f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight", data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 17023d0cb7..54cf7f02db 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -32,15 +32,6 @@ def _init_weight_names(self): self._ffn_norm_weight_name = f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" self._ffn_norm_bias_name = None - def load_hf_weights(self, weights): - kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") - if self.quant_cfg.quantized_weight: - _k_scale_weight_name = self._k_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix) - self._repeat_weight(_k_scale_weight_name, weights) - _v_scale_weight_name = self._v_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix) - self._repeat_weight(_v_scale_weight_name, weights) - return super().load_hf_weights(weights) - def _init_weight(self): self._init_qkv() self._init_o() From cb4aecc3d4df9fb238961dff647c432ac7657cfc Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 20 Jan 2026 14:16:46 +0000 Subject: [PATCH 33/43] fix unitest --- .../common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py index 671805a3d2..ab2cc4976a 100644 --- a/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py +++ b/unit_tests/common/fused_moe/test_moe_silu_and_mul_mix_quant_ep.py @@ -16,7 +16,7 @@ def is_fp8_native_supported(): import random from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) From 9242821752e2c47bdc007104a908cb21acf9b49f Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Tue, 13 Jan 2026 13:58:15 +0800 Subject: [PATCH 34/43] [MUSA] Add shell script to generate requirements-musa.txt and update doc (#1175) --- .gitignore | 1 + .../source/getting_started/installation.rst | 12 +- .../source/getting_started/installation.rst | 22 ++-- generate_requirements_musa.sh | 105 ++++++++++++++++++ 4 files changed, 127 insertions(+), 13 deletions(-) create mode 100755 generate_requirements_musa.sh diff --git a/.gitignore b/.gitignore index 6049c2cdbe..63408699f4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist .idea .vscode tmp/ +requirements-musa.txt diff --git a/docs/CN/source/getting_started/installation.rst b/docs/CN/source/getting_started/installation.rst index fb998b7567..5fa0e304d2 100755 --- a/docs/CN/source/getting_started/installation.rst +++ b/docs/CN/source/getting_started/installation.rst @@ 
-27,7 +27,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 前请确保你的docker设置中已经分配了足够的共享内存,否则可能导致 $ # 服务无法正常启动。 $ # 1.如果是纯文本服务,建议分配2GB以上的共享内存, 如果你的内存充足,建议分配16GB以上的共享内存. - $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. + $ # 2.如果是多模态服务,建议分配16GB以上的共享内存,具体可以根据实际情况进行调整. $ # 如果你没有足够的共享内存,可以尝试在启动服务的时候调低 --running_max_req_size 参数,这会降低 $ # 服务的并发请求数量,但可以减少共享内存的占用。如果是多模态服务,也可以通过降低 --cache_capacity $ # 参数来减少共享内存的占用。 @@ -38,7 +38,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 你也可以使用源码手动构建镜像并运行,建议手动构建镜像,因为更新比较频繁: .. code-block:: console - + $ # 进入代码仓库的根目录 $ cd /lightllm $ # 手动构建镜像, docker 目录下有不同功能场景的镜像构建文件,按需构建。 @@ -52,7 +52,7 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton 或者你也可以直接使用脚本一键启动镜像并且运行: .. code-block:: console - + $ # 查看脚本参数 $ python tools/quick_launch_docker.py --help @@ -80,6 +80,10 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton $ # 安装lightllm的依赖 (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # 安装lightllm的依赖 (摩尔线程 GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # 安装lightllm $ python setup.py install @@ -97,6 +101,6 @@ Lightllm 是一个纯python开发的推理框架,其中的算子使用triton .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + 具体原因可以参考:`issue `_ 和 `fix PR `_ diff --git a/docs/EN/source/getting_started/installation.rst b/docs/EN/source/getting_started/installation.rst index 75fa714764..6439c48de3 100755 --- a/docs/EN/source/getting_started/installation.rst +++ b/docs/EN/source/getting_started/installation.rst @@ -24,16 +24,16 @@ The easiest way to install Lightllm is using the official image. You can directl $ docker pull ghcr.io/modeltc/lightllm:main $ $ # Run,The current LightLLM service relies heavily on shared memory. - $ # Before starting, please make sure that you have allocated enough shared memory + $ # Before starting, please make sure that you have allocated enough shared memory $ # in your Docker settings; otherwise, the service may fail to start properly. $ # - $ # 1. For text-only services, it is recommended to allocate more than 2GB of shared memory. + $ # 1. For text-only services, it is recommended to allocate more than 2GB of shared memory. $ # If your system has sufficient RAM, allocating 16GB or more is recommended. - $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. + $ # 2.For multimodal services, it is recommended to allocate 16GB or more of shared memory. $ # You can adjust this value according to your specific requirements. $ # - $ # If you do not have enough shared memory available, you can try lowering - $ # the --running_max_req_size parameter when starting the service. + $ # If you do not have enough shared memory available, you can try lowering + $ # the --running_max_req_size parameter when starting the service. $ # This will reduce the number of concurrent requests, but also decrease shared memory usage. $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ @@ -42,13 +42,13 @@ The easiest way to install Lightllm is using the official image. You can directl You can also manually build the image from source and run it: .. code-block:: console - + $ # move into lightllm root dir $ cd /lightllm $ # Manually build the image $ docker build -t -f ./docker/Dockerfile . 
$ - $ # Run, + $ # Run, $ docker run -it --gpus all -p 8080:8080 \ $ --shm-size 2g -v your_local_path:/data/ \ $ /bin/bash @@ -56,7 +56,7 @@ You can also manually build the image from source and run it: Or you can directly use the script to launch the image and run it with one click: .. code-block:: console - + $ # View script parameters $ python tools/quick_launch_docker.py --help @@ -84,6 +84,10 @@ You can also install Lightllm from source: $ # Install Lightllm dependencies (cuda 12.4) $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu124 $ + $ # Install Lightllm dependencies (Moore Threads GPU) + $ ./generate_requirements_musa.sh + $ pip install -r requirements-musa.txt + $ $ # Install Lightllm $ python setup.py install @@ -101,5 +105,5 @@ You can also install Lightllm from source: .. code-block:: console $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps - + For specific reasons, please refer to: `issue `_ and `fix PR `_ \ No newline at end of file diff --git a/generate_requirements_musa.sh b/generate_requirements_musa.sh new file mode 100755 index 0000000000..f5bfb8ff83 --- /dev/null +++ b/generate_requirements_musa.sh @@ -0,0 +1,105 @@ +#!/bin/bash +# Script to generate requirements-musa.txt from requirements.txt +# MUSA is not compatible with CUDA packages, so they need to be removed +# Torch-related packages are pre-installed in the MUSA docker container + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +INPUT_FILE="${SCRIPT_DIR}/requirements.txt" +OUTPUT_FILE="${SCRIPT_DIR}/requirements-musa.txt" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: requirements.txt not found at $INPUT_FILE" + exit 1 +fi + +echo "Generating requirements-musa.txt from requirements.txt..." + +# Define patterns to remove (CUDA-specific packages) +# These packages are not compatible with MUSA +CUDA_PACKAGES=( + "^cupy" # cupy-cuda12x and similar + "^cuda_bindings" # CUDA bindings + "^nixl" # NIXL (NVIDIA Inter-node eXchange Library) + "^flashinfer" # flashinfer-python (CUDA-specific attention kernel) + "^sgl-kernel" # SGL kernel (CUDA-specific) +) + +# Define torch-related packages (pre-installed in MUSA container, remove version pins) +TORCH_PACKAGES=( + "^torch==" + "^torchvision==" +) + +# Create the output file with a header comment +cat > "$OUTPUT_FILE" << 'EOF' +# Requirements for MUSA (Moore Threads GPU) +# Auto-generated from requirements.txt by generate_requirements_musa.sh +# CUDA-specific packages have been removed +# Torch-related packages have version pins removed (pre-installed in MUSA container) + +EOF + +# Process the requirements file +while IFS= read -r line || [ -n "$line" ]; do + # Skip empty lines and comments (but keep them in output) + if [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]]; then + echo "$line" >> "$OUTPUT_FILE" + continue + fi + + # Extract package name (before ==, >=, <=, ~=, etc.) 
+ pkg_name=$(echo "$line" | sed -E 's/^([a-zA-Z0-9_-]+).*/\1/') + + # Check if this is a CUDA package to skip + skip=false + for pattern in "${CUDA_PACKAGES[@]}"; do + if [[ "$pkg_name" =~ $pattern ]]; then + echo " Removing CUDA package: $line" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Check if this is a torch-related package (remove version pin) + for pattern in "${TORCH_PACKAGES[@]}"; do + if [[ "$line" =~ $pattern ]]; then + # Remove version pin, keep just the package name + pkg_only=$(echo "$line" | sed -E 's/==.*//') + echo " Unpinning version for: $pkg_only (pre-installed in MUSA container)" + echo "$pkg_only" >> "$OUTPUT_FILE" + skip=true + break + fi + done + + if $skip; then + continue + fi + + # Keep the package as-is + echo "$line" >> "$OUTPUT_FILE" + +done < "$INPUT_FILE" + +# Add MUSA-specific packages at the end +cat >> "$OUTPUT_FILE" << 'EOF' + +# MUSA-specific packages +torch_musa +torchada +EOF + +echo "" +echo "Successfully generated: $OUTPUT_FILE" +echo "" +echo "Summary of changes:" +echo " - Removed CUDA-specific packages: cupy-cuda12x, cuda_bindings, nixl, flashinfer-python, sgl-kernel" +echo " - Unpinned torch-related packages: torch, torchvision (pre-installed in MUSA container)" +echo " - Added MUSA-specific packages: torch_musa, torchada" + From 35902c4bf3160e410979d2c063df7f6957b91754 Mon Sep 17 00:00:00 2001 From: sufubao <47234901+sufubao@users.noreply.github.com> Date: Tue, 13 Jan 2026 22:56:45 +0800 Subject: [PATCH 35/43] fix openai v1 (#1178) Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com> --- lightllm/models/qwen2/model.py | 2 +- lightllm/server/api_models.py | 5 +++-- lightllm/server/api_openai.py | 2 +- lightllm/server/core/objs/sampling_params.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index d2f067c42c..106610ff09 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -18,7 +18,7 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - if self.config["sliding_window"] is None: + if self.config.get("sliding_window", None) is None: self.config["sliding_window"] = self.max_total_token_num # rename key [SYM: to be confirmed] return diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 7b9cdd5012..f30ecc55fe 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -24,9 +24,10 @@ class Message(BaseModel): class Function(BaseModel): """Function descriptions.""" - description: Optional[str] = Field(default=None, examples=[None]) name: Optional[str] = None - parameters: Optional[object] = None + description: Optional[str] = Field(default=None, examples=[None]) + parameters: Optional[dict] = None + response: Optional[dict] = None class Tool(BaseModel): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 6a8c232dc5..d91bb1d947 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -81,7 +81,7 @@ def _process_tool_call_id( # SGLang sets call_item.tool_index to the *local* position inside that message. # Therefore, the index must be corrected by using # `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered. 
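The indexing rule spelled out in the comment above can be summarised in isolation. The sketch below is a hypothetical, standalone helper (``make_tool_call_id`` is not part of the server code); it only mirrors the expression used in ``_process_tool_call_id``.

.. code-block:: python

    # Standalone illustration of the id scheme: the tool index is local to the
    # current assistant message, so the number of tool calls already present in
    # the history is added to keep the id globally unique and properly ordered.
    def make_tool_call_id(name: str, history_tool_calls_cnt: int, tool_index: int) -> str:
        return f"functions.{name}:{history_tool_calls_cnt + tool_index}"

    # Two tool calls already in the history, current call is the first in its message:
    assert make_tool_call_id("get_weather", 2, 0) == "functions.get_weather:2"
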
- tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}" + tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}" logger.debug( f"Process tool call idx, parser: {tool_call_parser}, \ tool_call_id: {tool_call_id}, \ diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index f073319d79..d955aa6a87 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -7,7 +7,7 @@ _SAMPLING_EPS = 1e-5 DEFAULT_INPUT_PENALTY = os.getenv("INPUT_PENALTY", "False").upper() in ["ON", "TRUE", "1"] -SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "True").upper() in ["ON", "TRUE", "1"] +SKIP_SPECIAL_TOKENS = os.getenv("SKIP_SPECIAL_TOKENS", "False").upper() in ["ON", "TRUE", "1"] # 从环境变量获取最大长度限制 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256)) From 0f89a7affe59e593665f434bbb7755b8eaa9f723 Mon Sep 17 00:00:00 2001 From: wanzihao <1060304770@qq.com> Date: Thu, 15 Jan 2026 13:12:48 +0800 Subject: [PATCH 36/43] add diverse_stage2 add optimize diverse_stage1 (#1174) Co-authored-by: wangzaijun --- ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + 
...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + ...h.float16,q_head_dim=128}_NVIDIA_H200.json | 1 + ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 1 + ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 1 + .../basemodel/attention/triton/int8kv.py | 2 +- ...se.py => int8kv_flash_decoding_diverse.py} | 38 +-- ...> int8kv_flash_decoding_diverse_stage1.py} | 32 +- .../int8kv_flash_decoding_diverse_stage2.py | 306 ++++++++++++++++++ ...> int8kv_flash_decoding_diverse_stage3.py} | 0 ...ad_dim=128}_NVIDIA_GeForce_RTX_4090_D.json | 22 ++ ...head_dim=128}_NVIDIA_GeForce_RTX_5090.json | 22 ++ .../benchmark/static_inference/model_infer.py | 5 +- .../llama_gqa_diverse_decode_stage1_tuning.py | 2 +- .../llama_gqa_diverse_decode_stage2_tuning.py | 296 +++++++++++++++++ ... => test_int8kv_flash_decoding_diverse.py} | 27 +- ...t_int8kv_flash_decoding_diverse_stage1.py} | 108 ++++++- ...st_int8kv_flash_decoding_diverse_stage2.py | 293 +++++++++++++++++ ...t_int8kv_flash_decoding_diverse_stage3.py} | 2 +- ...pl_int8kv_flash_decoding_diverse_stage2.py | 132 -------- 77 files changed, 1160 insertions(+), 189 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 
lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse.py => int8kv_flash_decoding_diverse.py} (75%) rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse_stage1.py => int8kv_flash_decoding_diverse_stage1.py} (90%) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/{ppl_int8kv_flash_decoding_diverse_stage3.py => int8kv_flash_decoding_diverse_stage3.py} (100%) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json create mode 100644 test/kernel/llama_gqa_diverse_decode_stage2_tuning.py rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse.py => test_int8kv_flash_decoding_diverse.py} (85%) rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse_stage1.py => test_int8kv_flash_decoding_diverse_stage1.py} (53%) create mode 100644 unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py rename unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/{test_ppl_int8kv_flash_decoding_diverse_stage3.py => test_int8kv_flash_decoding_diverse_stage3.py} (96%) delete mode 100644 unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py diff --git 
a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..67dd3852c5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..555386ebdd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..6f92439c1c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, 
"num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..67dd3852c5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..555386ebdd --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..6f92439c1c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, 
"256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7f69e86a86 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7d8dc868c3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..3e543b2ea3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": 
{"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7f69e86a86 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7d8dc868c3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..3e543b2ea3 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=16,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, 
"num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..a3b0edde6d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..22d1ce6f69 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..328fcec837 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, 
"num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..a3b0edde6d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..22d1ce6f69 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..328fcec837 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": 
{"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..4c4ae86241 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..61884a9375 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..037bfd2913 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, 
"num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..4c4ae86241 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..61884a9375 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..037bfd2913 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=2,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, 
"num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..e2028e2d2a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7e99dc1be2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d6b46dda8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, 
"128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..e2028e2d2a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7e99dc1be2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d6b46dda8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 4}, "32": {"BLOCK_N": 64, 
"num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7795b47e72 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..ff4d6efd49 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..1d8ca6967b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, 
"num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..7795b47e72 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..ff4d6efd49 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..1d8ca6967b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=4,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": 
{"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..2f1cd5dfd5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a369088bfb --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..cb4e6a0d3e --- /dev/null +++ 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..2f1cd5dfd5 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a369088bfb --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 
0000000000..cb4e6a0d3e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..60827b791e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7b42cad466 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json 
b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..9bb49d70b7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..60827b791e --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..7b42cad466 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git 
a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..9bb49d70b7 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=5,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..bd3d1c418b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..f1b3539f5d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ 
No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..e12c05b966 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..bd3d1c418b --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 3}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..f1b3539f5d --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 1}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 2, "num_stages": 2}, "128": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_N": 32, 
"num_warps": 4, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..e12c05b966 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=4,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 4}, "256": {"BLOCK_N": 64, "num_warps": 2, "num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..c83dca52d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..59a3e1051c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, 
"num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b7c4eaa9f --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..c83dca52d2 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..59a3e1051c --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}, "8192": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, 
"128": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "256": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b7c4eaa9f --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage1:v2/{block_seq=256,gqa_group_size=8,max_batch_group_size=8,out_dtype=torch.float16,q_head_dim=128}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"4096": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}, "8192": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 5}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..abd760af04 --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1 @@ +{"32": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 3}, "16": {"BLOCK_N": 16, "num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "64": {"8": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 4}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 1}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 9}, "32": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..a560ce9e1a --- /dev/null +++ b/lightllm/common/all_kernel_configs/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1 @@ 
+{"32": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 2}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 11}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "64": {"8": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 7}, "16": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 10}, "32": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 4}}, "128": {"8": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 3}, "16": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 7}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 16, "num_warps": 2, "num_stages": 3}}, "256": {"8": {"BLOCK_N": 64, "num_warps": 16, "num_stages": 2}, "16": {"BLOCK_N": 64, "num_warps": 8, "num_stages": 5}, "32": {"BLOCK_N": 64, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_N": 32, "num_warps": 4, "num_stages": 2}}} \ No newline at end of file diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index 6a795c4376..975d7b629c 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -158,7 +158,7 @@ def diverse_decode_att( alloc_func=torch.empty, ) -> torch.Tensor: - from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + from ...triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py similarity index 75% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py index 6efb030ce6..ad6a8b5b3a 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py @@ -2,8 +2,9 @@ import torch from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.infer_struct import InferStateInfo -from .ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 -from .ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 +from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size @@ -37,10 +38,10 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device="cuda" ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" ) current_stream = torch.cuda.current_stream() @@ -65,21 +66,20 @@ def 
token_decode_attention_flash_decoding( ) stream2.wait_stream(current_stream) with torch.cuda.stream(stream2): - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - BLOCK_SEQ, - mid_o, - mid_o_logexpsum, - 1.0 / (head_dim ** 0.5), - q.view(calcu_shape1), - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.b_shared_seq_len, - infer_state.max_kv_seq_len, + flash_decode_stage2( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + b_shared_seq_len=infer_state.b_shared_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, ) current_stream.wait_stream(stream1) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py similarity index 90% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 7403f6dd5c..4dfaffef68 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -9,7 +9,7 @@ class GQADiverseDecodeStage1KernelConfig(KernelConfigs): - kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v1" + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage1:v2" @classmethod @lru_cache(maxsize=200) @@ -113,6 +113,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( BLOCK_N: tl.constexpr, BLOCK_BATCH: tl.constexpr, KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, ): cur_batch = tl.program_id(0) shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch) @@ -128,6 +129,7 @@ def _fwd_kernel_flash_decode_diverse_stage1( cur_q_head_range = tl.where(cur_q_head_range < q_head_end_index, cur_q_head_range, cur_kv_head * gqa_group_size) offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) cur_batch_seq_len = tl.load(b_shared_seq_len + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_batch_start_index = seq_start_block * BLOCK_SEQ @@ -162,25 +164,37 @@ def _fwd_kernel_flash_decode_diverse_stage1( mask=n_mask, other=0, ).to(tl.int64) - off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] - off_k_scale = off_k // KV_QUANT_GROUP_SIZE + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) att_value = tl.dot(q, k.to(q.dtype)) att_value *= sm_scale att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = 
k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] v = tl.load( - V + off_k.T, + V + off_v, mask=n_mask[:, None], other=0, ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) v_scale = tl.load( - V_scale + off_k_scale.T, - mask=n_mask[:, None], + V_scale + off_k_scale, + mask=n_mask[None, :], other=0.0, ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) cur_max_logic = tl.max(att_value, axis=1) new_max_logic = tl.maximum(cur_max_logic, max_logic) @@ -274,7 +288,10 @@ def flash_decode_stage1( BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) if BLOCK_HEAD * BLOCK_BATCH < 16: BLOCK_BATCH = 16 // BLOCK_HEAD - + assert k.stride() == v.stride() + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, @@ -314,6 +331,7 @@ def flash_decode_stage1( BLOCK_N=BLOCK_N, BLOCK_BATCH=BLOCK_BATCH, KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 0000000000..f5c0b9c395 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,306 @@ +import torch +import triton +import triton.language as tl +from typing import Optional + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict +from lightllm.common.triton_utils.autotuner import autotune, Autotuner + + +class GQADiverseDecodeStage2KernelConfig(KernelConfigs): + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage2:v1" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + ) -> dict: + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + batch_size_config: dict = finded_config[ + min( + finded_config.keys(), + key=lambda x: abs(int(x) - avg_seq_len_in_batch), + ) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + config_json: Dict[int, Dict[int, Dict]], + ): + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def _fwd_kernel_flash_decode_diverse_stage2( + Q, + stride_qbs, + stride_qh, + stride_qd, + K, + K_scale, + stride_kbs, + stride_kh, + stride_kd, + V, + V_scale, + stride_vbs, + stride_vh, + stride_vd, + sm_scale, + Req_to_tokens, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + 
B_req_idx, + B_Seqlen, + b_shared_seq_len, + Mid_O, # [batch, head, seq_block_num, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, gqa_group_size) + + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) + cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_shared_len + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + store_seq_block = seq_start_block + tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + + block_n_size = tl.cdiv( + tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), + BLOCK_N, + ) + + if block_n_size == 0: + return + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + q = tl.load(Q + off_q) + + sum_exp = tl.zeros([gqa_group_size], dtype=tl.float32) + max_logic = tl.zeros([gqa_group_size], dtype=tl.float32) - float("inf") + acc = tl.zeros([gqa_group_size, BLOCK_HEADDIM], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + # q (4, 128) k (128, BLOCK_N) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + 
off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + store_seq_block * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + store_seq_block + tl.store( + Mid_O + off_mid_o, + (acc / sum_exp[:, None]), + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + (max_logic + tl.log(sum_exp)), + ) + return + + +@torch.no_grad() +def flash_decode_stage2( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + Req_to_tokens: torch.Tensor, + B_req_idx: torch.Tensor, + B_Seqlen: torch.Tensor, + b_shared_seq_len: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, + run_config: Optional[dict] = None, +): + if not run_config: + run_config = GQADiverseDecodeStage2KernelConfig.try_to_get_best_config( + batch_size=int(q.shape[0]), + avg_seq_len_in_batch=max_len_in_batch, + gqa_group_size=int(q.shape[1] // k.shape[1]), + q_head_dim=int(q.shape[2]), + block_seq=block_seq, + out_dtype=q.dtype, + ) + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + BLOCK_SEQ = block_seq + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + assert triton.next_power_of_2(Lk) == Lk + KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] + assert KV_QUANT_GROUP_SIZE == 8 + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + + assert k.stride() == v.stride() + + _fwd_kernel_flash_decode_diverse_stage2[grid]( + Q=q, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + K=k, + K_scale=k_scale, + stride_kbs=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + V=v, + V_scale=v_scale, + stride_vbs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + sm_scale=sm_scale, + Req_to_tokens=Req_to_tokens, + stride_req_to_tokens_b=Req_to_tokens.stride(0), + stride_req_to_tokens_s=Req_to_tokens.stride(1), + B_req_idx=B_req_idx, + B_Seqlen=B_Seqlen, + b_shared_seq_len=b_shared_seq_len, + Mid_O=mid_out, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] + stride_mid_o_eb=mid_out_logsumexp.stride(0), + stride_mid_o_eh=mid_out_logsumexp.stride(1), + stride_mid_o_es=mid_out_logsumexp.stride(2), + gqa_group_size=gqa_group_size, + BLOCK_SEQ=block_seq, + BLOCK_HEADDIM=Lk, + BLOCK_N=BLOCK_N, + KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, + num_warps=num_warps, + num_stages=num_stages, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py similarity index 100% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/ppl_int8kv_flash_decoding_diverse_stage3.py rename to 
lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage3.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json new file mode 100644 index 0000000000..9f44ee6c30 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_4090_D/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_4090_D.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json new file mode 100644 index 0000000000..4fa2f949f2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_GeForce_RTX_5090/_fwd_kernel_flash_decode_diverse_stage2:v1/{block_seq=256,gqa_group_size=4,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_GeForce_RTX_5090.json @@ -0,0 +1,22 @@ +{ + "16_8192": { + "BLOCK_N": 32, + "num_stages": 2, + "num_warps": 4 + }, + "32_8192": { + "BLOCK_N": 32, + "num_stages": 4, + "num_warps": 4 + }, + "64_8192": { + "BLOCK_N": 16, + "num_stages": 3, + "num_warps": 2 + }, + "8_8192": { + "BLOCK_N": 16, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index a8abd2ae64..7f1c2b493f 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -41,7 +41,10 @@ def test_model_inference(args): "run_mode": "normal", "max_seq_length": args.max_req_total_len, "disable_cudagraph": args.disable_cudagraph, - "mode": args.mode, + "llm_prefill_att_backend": args.llm_prefill_att_backend, + "llm_decode_att_backend": args.llm_decode_att_backend, + "llm_kv_type": args.llm_kv_type, + "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } proc = multiprocessing.Process( target=tppart_model_infer, diff --git a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py index d391d30650..f32b093448 100644 --- a/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py +++ b/test/kernel/llama_gqa_diverse_decode_stage1_tuning.py @@ -4,7 +4,7 @@ import torch.multiprocessing as mp from typing import List from lightllm.utils.log_utils import init_logger -from 
lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage1 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( flash_decode_stage1, GQADiverseDecodeStage1KernelConfig, ) diff --git a/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py new file mode 100644 index 0000000000..13c8945e59 --- /dev/null +++ b/test/kernel/llama_gqa_diverse_decode_stage2_tuning.py @@ -0,0 +1,296 @@ +import torch +import os +import torch.multiprocessing as mp +from typing import List +from lightllm.utils.log_utils import init_logger +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, + GQADiverseDecodeStage2KernelConfig, +) +from lightllm.utils.watchdog_utils import Watchdog + +logger = init_logger(__name__) + + +def set_seed(): + import torch + import random + import numpy as np + + seed = 42 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return + + +@torch.no_grad() +def test_decode_attentions( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int = 20, + **run_config, +): + set_seed() + shared_seq_len = 0 + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + max_len_in_batch = 8192 + quant_group_size = 8 + + args = [] + for _ in range(test_count): + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=dtype, device="cuda") / 10 + kv_shape = (batch_size * seq_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=dtype, device="cuda") + Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, seq_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), + dtype=q.dtype, + device="cuda", + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), + dtype=q.dtype, + device="cuda", + ) + arg_list, kwargs = ( + q, + k, + k_scale, + v, + v_scale, + Req_to_tokens, + B_req_idx, + b_seq_len, + b_shared_seq_len, + max_len_in_batch, + mid_out, + mid_out_logsumexp, + block_seq, + ), dict(run_config=run_config) + args.append((arg_list, kwargs)) + + graph = torch.cuda.CUDAGraph() + arg_list, kwargs = args[0] + flash_decode_stage2(*arg_list, **kwargs) + with torch.cuda.graph(graph): + for index in range(test_count): + arg_list, kwargs = args[index] + flash_decode_stage2(*arg_list, **kwargs) + + graph.replay() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + + cost_time = start_event.elapsed_time(end_event=end_event) + + logger.info(f"bf16 
{seq_len} cost time: {cost_time} ms") + return cost_time + + +def worker( + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, + test_configs, + queue, +): + dog = Watchdog(timeout=10) + dog.start() + + try: + for index in range(len(test_configs)): + tuning_config = test_configs[index] + cost_time = test_decode_attentions( + block_seq=block_seq, + batch_size=batch_size, + seq_len=seq_len, + dtype=dtype, + test_count=test_count, + **tuning_config, + ) + dog.heartbeat() + queue.put(cost_time) + except Exception as ex: + logger.error(str(ex) + f" config {tuning_config} batch_size {batch_size} seq_len {seq_len} dtype {dtype}") + import sys + import traceback + + traceback.print_exc() + sys.exit(-1) + pass + + +def get_test_configs(split_id, split_count): + index = 0 + for block_n in [16, 32, 64]: + for num_warps in [ + 2, + 4, + 8, + 16, + ]: + for num_stages in [ + 1, + 2, + 3, + 4, + 5, + 7, + 9, + 10, + 11, + ]: + t_config = { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + if index % split_count == split_id: + yield t_config + index += 1 + else: + index += 1 + + +def tuning_configs( + device_id: int, + device_count: int, + block_seq: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, + test_count: int, +): + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) + best_config, best_cost_time = None, 10000000 + queue = mp.Queue() + test_configs = [] + for t_config in get_test_configs(device_id, device_count): + test_configs.append(t_config) + if len(test_configs) < 64: + continue + + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + while len(test_configs) != 0: + p = mp.Process( + target=worker, + args=( + block_seq, + batch_size, + seq_len, + dtype, + test_count, + test_configs, + queue, + ), + ) + p.start() + p.join() + + while len(test_configs) != 0: + try: + cost_time = queue.get_nowait() + logger.info(f"get {test_configs[0]} cost_time: {cost_time}") + if cost_time < best_cost_time: + best_config = test_configs[0] + best_cost_time = cost_time + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + except: + logger.info(f"cur best {best_config}, {best_cost_time}") + del test_configs[0:1] + break + + logger.info(f"{best_config} best cost: {best_cost_time}") + return best_config, best_cost_time + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + + from lightllm.utils.tuning_utils import mp_tuning + import collections + + block_seq = 256 + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + num_heads = 32 + kv_head_num = 8 + q_head_dim = 128 + gqa_group_size = num_heads // kv_head_num + + store_json_ans = collections.defaultdict(dict) + + for seq_len in seq_lens: + for batch_size in batch_sizes: + ans = mp_tuning( + tuning_configs, + { + "block_seq": block_seq, + "batch_size": batch_size, + "seq_len": seq_len, + "dtype": torch.bfloat16, + "test_count": 1, + }, + ) + store_json_ans[seq_len][batch_size] = ans + + 
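The tuning entry point above sweeps BLOCK_N, num_warps and num_stages for each (batch_size, seq_len) pair and keeps the fastest configuration. Stripped of the multiprocessing, watchdog and CUDA-graph machinery, the underlying pattern is an exhaustive timed grid search like the sketch below (run_once stands in for a single kernel launch with a given config; the parameter ranges are illustrative, not the script's exact lists).

import itertools
import torch

def grid_search(run_once, block_ns=(16, 32, 64), warps=(2, 4, 8), stages=(1, 2, 3, 4)):
    best_cfg, best_ms = None, float("inf")
    for block_n, num_warps, num_stages in itertools.product(block_ns, warps, stages):
        cfg = {"BLOCK_N": block_n, "num_warps": num_warps, "num_stages": num_stages}
        run_once(cfg)  # warm up / trigger compilation
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(20):
            run_once(cfg)
        end.record()
        end.synchronize()
        ms = start.elapsed_time(end) / 20  # average latency per launch
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg, best_ms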
GQADiverseDecodeStage2KernelConfig.save_config( + gqa_group_size=gqa_group_size, + q_head_dim=q_head_dim, + block_seq=block_seq, + out_dtype=str(torch.bfloat16), + config_json=store_json_ans, + ) diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py similarity index 85% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py index ac18ffb955..a01bbf32d8 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse.py @@ -1,7 +1,5 @@ import pytest -pytest.skip(reason="need install lightllmKernel", allow_module_level=True) - import torch from lightllm.utils.light_utils import light_ops @@ -42,31 +40,32 @@ def __init__( # @pytest.mark.parametrize("shared_seq_len", [512]) @pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550]) -def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len): +@pytest.mark.parametrize("batch_size", list(range(6, 121, 6))) +def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len, batch_size): """ - 测试 ppl_int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding + 测试 int8kv_flash_decoding_diverse 的 token_decode_attention_flash_decoding 与 ppl_int8kv_flash_decoding (baseline) 的对比。 """ - from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse import ( + from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse import ( token_decode_attention_flash_decoding as diverse_attention, ) from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( token_decode_attention_flash_decoding as baseline_attention, ) - batch_size = 6 num_heads = 32 kv_head_num = 8 mark_shared_group_size = 3 - seq_len = 1024 + seq_len = 3547 head_dim = 128 quant_group_size = 8 + max_len_in_batch = 8192 test_dtype = torch.bfloat16 # 创建测试数据 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim) + kv_scale_shape = (batch_size * max_len_in_batch, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") @@ -77,7 +76,9 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0 - req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + req_to_tokens = torch.arange(0, max_len_in_batch * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, max_len_in_batch + ) for i in range(batch_size): if i % mark_shared_group_size != 0: req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len] @@ -91,7 +92,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 baseline 的 
infer_state (不需要 b_shared_seq_len) baseline_infer_state = MockInferState( batch_size=batch_size, - max_kv_seq_len=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -100,7 +101,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le # 创建 diverse 的 infer_state diverse_infer_state = MockInferState( batch_size=batch_size, - max_kv_seq_len=seq_len, + max_kv_seq_len=max_len_in_batch, req_to_tokens=req_to_tokens, b_req_idx=b_req_idx, b_seq_len=b_seq_len, @@ -129,7 +130,7 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le alloc_tensor_func=alloc_tensor_func, ) - print(f"\nshared_seq_len={shared_seq_len}") + print(f"\nshared_seq_len={shared_seq_len}\nbatch_size={batch_size}") print(f"baseline_out: {baseline_out[0, 0, :4]}") print(f"diverse_out: {diverse_out[0, 0, :4]}") print(f"max diff: {(baseline_out - diverse_out).abs().max()}") diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py similarity index 53% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py index 5ef36e38e2..f3cb8de463 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage1.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage1.py @@ -1,41 +1,48 @@ import pytest import torch -from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage1 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage1 import ( flash_decode_stage1, ) -@pytest.fixture -def setup_tensors(): - batch_size = 4 - num_heads = 4 - kv_head_num = 1 - seq_len = 256 +def create_tensors( + batch_size=4, + num_heads=4, + kv_head_num=1, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): head_dim = 128 - max_len_in_batch = seq_len block_seq = 256 - max_batch_group_size = 4 quant_group_size = 8 test_dtype = torch.bfloat16 - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) + Req_to_tokens = torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, 
req_to_tokens_len + ) B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") b_shared_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" ) mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" ) return { @@ -56,6 +63,11 @@ def setup_tensors(): } +@pytest.fixture +def setup_tensors(): + return create_tensors() + + def test_flash_decode_stage1_execution(setup_tensors): flash_decode_stage1( q=setup_tensors["q"], @@ -106,3 +118,71 @@ def test_flash_decode_stage1_execution(setup_tensors): assert torch.allclose( setup_tensors["mid_out_logsumexp"], true_mid_out_logsumexp, atol=1e-2 ), "LogSumExp output does not match expected values" + + +def autotune_and_benchmark(): + import triton + + batch_sizes = [8, 16, 32, 64] + seq_lens = [1024, 2048, 4096] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + batch_size=batch, + num_heads=32, + kv_head_num=8, + seq_len=seq, + max_len_in_batch=8192, + max_batch_group_size=8, + kv_len=seq, + req_to_tokens_len=seq, + ) + + def fn_triton(st=setup_tensors): + return flash_decode_stage1( + q=st["q"], + k=st["k"], + k_scale=st["k_scale"], + v=st["v"], + v_scale=st["v_scale"], + Req_to_tokens=st["Req_to_tokens"], + B_req_idx=st["B_req_idx"], + b_shared_seq_len=st["b_shared_seq_len"], + b_mark_shared_group=st["b_mark_shared_group"], + max_len_in_batch=st["max_len_in_batch"], + mid_out=st["mid_out"], + mid_out_logsumexp=st["mid_out_logsumexp"], + block_seq=st["block_seq"], + max_batch_group_size=st["max_batch_group_size"], + ) + + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12}") + print(f"{'-'*80}") + for r in results: + print(f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f}") + print(f"{'='*80}") + + +if __name__ == "__main__": + autotune_and_benchmark() diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py new file mode 100644 index 0000000000..c7d4442543 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage2.py @@ -0,0 +1,293 @@ +import pytest +import torch +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage2 import ( + flash_decode_stage2, +) + + +def create_tensors( + shared_seq_len, + batch_size=4, + seq_len=256, + max_len_in_batch=8192, + max_batch_group_size=4, + kv_len=None, + req_to_tokens_len=None, +): + num_heads = 32 + kv_head_num = 8 + head_dim = 128 + block_seq = 256 + quant_group_size 
= 8 + + test_dtype = torch.bfloat16 + + kv_len = max_len_in_batch if kv_len is None else kv_len + req_to_tokens_len = max_len_in_batch if req_to_tokens_len is None else req_to_tokens_len + + kv_shape = (batch_size * kv_len, kv_head_num, head_dim) + kv_scale_shape = (batch_size * kv_len, kv_head_num, head_dim // quant_group_size) + + q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") + k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") + v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") + Req_to_tokens = torch.arange(0, req_to_tokens_len * batch_size, dtype=torch.int32, device="cuda").view( + batch_size, req_to_tokens_len + ) + B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") + b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") + b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") + mid_out = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" + ) + mid_out_logsumexp = torch.zeros( + size=(batch_size, num_heads, (max_len_in_batch // block_seq) + 2), dtype=q.dtype, device="cuda" + ) + + return { + "q": q, + "k": k, + "k_scale": k_scale, + "v": v, + "v_scale": v_scale, + "Req_to_tokens": Req_to_tokens, + "B_req_idx": B_req_idx, + "b_seq_len": b_seq_len, + "b_shared_seq_len": b_shared_seq_len, + "b_mark_shared_group": b_mark_shared_group, + "max_len_in_batch": max_len_in_batch, + "mid_out": mid_out, + "mid_out_logsumexp": mid_out_logsumexp, + "block_seq": block_seq, + "max_batch_group_size": max_batch_group_size, + "head_dim": head_dim, + } + + +@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) +def test_flash_decode_stage2_execution(shared_seq_len): + setup_tensors = create_tensors(shared_seq_len) + + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ + "block_seq" + ] + mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] + mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] + + q = setup_tensors["q"] + k = setup_tensors["k"] + v = setup_tensors["v"] + true_mid_out = torch.zeros_like(mid_out) + true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) + new_q = q + new_k = k.to(q.dtype) + new_v = v.to(q.dtype) + + b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] + req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] + + from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( + flash_decode_stage1 as gqa_flash_decode_stage1, 
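The reference computation in this test dequantizes the int8 KV cache simply by casting it to the query dtype, which is exact here because the scales are all ones. For the general group-quantized layout (one scale per group of quant_group_size channels along head_dim), dequantization looks like the sketch below; it is illustrative and not a helper defined in this patch.

import torch

def dequant_int8_kv(kv_int8: torch.Tensor, kv_scale: torch.Tensor, group_size: int = 8) -> torch.Tensor:
    # kv_int8:  [tokens, kv_heads, head_dim] int8 values
    # kv_scale: [tokens, kv_heads, head_dim // group_size] per-group scales
    t, h, d = kv_int8.shape
    grouped = kv_int8.view(t, h, d // group_size, group_size).to(kv_scale.dtype)
    return (grouped * kv_scale.unsqueeze(-1)).view(t, h, d)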
+ ) + + gqa_flash_decode_stage1( + q=new_q, + k=new_k, + v=new_v, + Req_to_tokens=req_to_tokens, + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=b_seq_len, + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=true_mid_out, + mid_out_logsumexp=true_mid_out_logsumexp, + block_seq=setup_tensors["block_seq"], + ) + print(f"\nshared_seq_len={shared_seq_len}") + print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") + print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") + abs_diff = (mid_out - true_mid_out).abs() + max_diff = abs_diff.max() + max_diff_idx = abs_diff.argmax() + max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) + mid_out_value = mid_out[max_diff_idx_unraveled] + true_mid_out_value = true_mid_out[max_diff_idx_unraveled] + print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") + + assert torch.allclose( + mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 + ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" + assert torch.allclose( + mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 + ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" + + +if __name__ == "__main__": + import importlib + import triton + from lightllm.utils.light_utils import light_ops + + batch_sizes = [8, 16, 32, 64] + seq_lens = [32, 64, 128, 256] + + results = [] + for batch in batch_sizes: + for seq in seq_lens: + # Clear GPU cache to reduce CUDA Graph capture failures. + torch.cuda.empty_cache() + + setup_tensors = create_tensors( + shared_seq_len=0, + batch_size=batch, + seq_len=seq, + max_len_in_batch=8192, + kv_len=seq, + req_to_tokens_len=seq, + ) + + # Outputs for CUDA implementation + mid_out_cuda = setup_tensors["mid_out"].clone() + mid_out_logsumexp_cuda = setup_tensors["mid_out_logsumexp"].clone() + + # Outputs for Triton implementation + mid_out_triton = setup_tensors["mid_out"].clone() + mid_out_logsumexp_triton = setup_tensors["mid_out_logsumexp"].clone() + + # Run CUDA to get reference + light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + mid_out_cuda, + mid_out_logsumexp_cuda, + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + + # Run Triton + flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=mid_out_triton, + mid_out_logsumexp=mid_out_logsumexp_triton, + block_seq=setup_tensors["block_seq"], + ) + + # Compare results + diff_mid_out = torch.abs(mid_out_cuda - mid_out_triton) + diff_logsumexp = torch.abs(mid_out_logsumexp_cuda - mid_out_logsumexp_triton) + max_diff_out = diff_mid_out.max().item() + max_diff_logsumexp = diff_logsumexp.max().item() + mean_diff_out = diff_mid_out.mean().item() + mean_diff_logsumexp = diff_logsumexp.mean().item() + + cos_sim_out = torch.nn.functional.cosine_similarity( + mid_out_cuda.flatten(), mid_out_triton.flatten(), dim=0 + 
).item() + cos_sim_logsumexp = torch.nn.functional.cosine_similarity( + mid_out_logsumexp_cuda.flatten(), mid_out_logsumexp_triton.flatten(), dim=0 + ).item() + + print(f"\n[batch={batch}, seq={seq}] Consistency check:") + print(" mid_out:") + print(f" max_diff: {max_diff_out:.6f}, mean_diff: {mean_diff_out:.6f}, cosine_sim: {cos_sim_out:.8f}") + print(" logsumexp:") + print( + f" max_diff: {max_diff_logsumexp:.6f}, " + f"mean_diff: {mean_diff_logsumexp:.6f}, " + f"cosine_sim: {cos_sim_logsumexp:.8f}" + ) + + # Performance + fn_cuda = lambda: light_ops.group8_int8kv_flashdecoding_diverse_stage2( + setup_tensors["block_seq"], + setup_tensors["mid_out"], + setup_tensors["mid_out_logsumexp"], + 1.0 / (setup_tensors["head_dim"] ** 0.5), + setup_tensors["q"], + setup_tensors["k"], + setup_tensors["k_scale"], + setup_tensors["v"], + setup_tensors["v_scale"], + setup_tensors["Req_to_tokens"], + setup_tensors["B_req_idx"], + setup_tensors["b_seq_len"], + setup_tensors["b_shared_seq_len"], + setup_tensors["max_len_in_batch"], + ) + ms_cuda = triton.testing.do_bench_cudagraph(fn_cuda, rep=100) + + fn_triton = lambda: flash_decode_stage2( + q=setup_tensors["q"], + k=setup_tensors["k"], + k_scale=setup_tensors["k_scale"], + v=setup_tensors["v"], + v_scale=setup_tensors["v_scale"], + Req_to_tokens=setup_tensors["Req_to_tokens"], + B_req_idx=setup_tensors["B_req_idx"], + B_Seqlen=setup_tensors["b_seq_len"], + b_shared_seq_len=setup_tensors["b_shared_seq_len"], + max_len_in_batch=setup_tensors["max_len_in_batch"], + mid_out=setup_tensors["mid_out"], + mid_out_logsumexp=setup_tensors["mid_out_logsumexp"], + block_seq=setup_tensors["block_seq"], + ) + ms_triton = triton.testing.do_bench_cudagraph(fn_triton, rep=100) + + results.append( + { + "batch_size": batch, + "seq_len": seq, + "triton_ms": ms_triton, + "cuda_ms": ms_cuda, + } + ) + print(results[-1]) + + del setup_tensors + + print(f"\n{'='*80}") + print("SUMMARY - Performance Comparison") + print(f"{'='*80}") + print(f"{'batch_size':<8} {'seq_len':<12} {'triton_ms':<12} {'cuda_ms':<12} {'vs cuda':<10}") + print(f"{'-'*80}") + for r in results: + vs_cuda = f"{r['cuda_ms']/r['triton_ms']:.2f}x" + emoji = "🎉" if r["triton_ms"] < r["cuda_ms"] else "" + print( + f"{r['batch_size']:<8} {r['seq_len']:<12} {r['triton_ms']:<12.3f} {r['cuda_ms']:<12.3f}" + f"{vs_cuda:<10} {emoji}" + ) + print(f"{'='*80}") diff --git a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py similarity index 96% rename from unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py rename to unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py index 18550982b9..c1a0ca1e58 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage3.py +++ b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_int8kv_flash_decoding_diverse_stage3.py @@ -1,6 +1,6 @@ import pytest import torch -from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding_diverse_stage3 import ( +from lightllm.common.basemodel.triton_kernel.att.decode_att.int8kv.int8kv_flash_decoding_diverse_stage3 import ( flash_diverse_decode_stage3, ) diff --git 
a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py b/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py deleted file mode 100644 index cde7734817..0000000000 --- a/unit_tests/common/basemodel/triton_kernel/att/decode_att/int8kv/test_ppl_int8kv_flash_decoding_diverse_stage2.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest - -pytest.skip(reason="need install lightllmkernel", allow_module_level=True) - -import torch -from lightllm.utils.light_utils import light_ops - - -def create_tensors(shared_seq_len): - batch_size = 4 - num_heads = 32 - kv_head_num = 8 - seq_len = 256 - head_dim = 128 - max_len_in_batch = seq_len - block_seq = 256 - max_batch_group_size = 4 - quant_group_size = 8 - - test_dtype = torch.bfloat16 - - kv_shape = (batch_size * seq_len, kv_head_num, head_dim) - kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size) - - q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda") - k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda") - v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") - Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len) - B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda") - b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda") - b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda") - b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda") - mid_out = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda" - ) - mid_out_logsumexp = torch.zeros( - size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda" - ) - - return { - "q": q, - "k": k, - "k_scale": k_scale, - "v": v, - "v_scale": v_scale, - "Req_to_tokens": Req_to_tokens, - "B_req_idx": B_req_idx, - "b_seq_len": b_seq_len, - "b_shared_seq_len": b_shared_seq_len, - "b_mark_shared_group": b_mark_shared_group, - "max_len_in_batch": max_len_in_batch, - "mid_out": mid_out, - "mid_out_logsumexp": mid_out_logsumexp, - "block_seq": block_seq, - "max_batch_group_size": max_batch_group_size, - "head_dim": head_dim, - } - - -@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255]) -def test_flash_decode_stage2_execution(shared_seq_len): - setup_tensors = create_tensors(shared_seq_len) - - light_ops.group8_int8kv_flashdecoding_diverse_stage2( - setup_tensors["block_seq"], - setup_tensors["mid_out"], - setup_tensors["mid_out_logsumexp"], - 1.0 / (setup_tensors["head_dim"] ** 0.5), - setup_tensors["q"], - setup_tensors["k"], - setup_tensors["k_scale"], - setup_tensors["v"], - setup_tensors["v_scale"], - setup_tensors["Req_to_tokens"], - setup_tensors["B_req_idx"], - setup_tensors["b_seq_len"], - setup_tensors["b_shared_seq_len"], - setup_tensors["max_len_in_batch"], - ) - seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[ - "block_seq" - ] - mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :] - mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:] - - q = 
setup_tensors["q"] - k = setup_tensors["k"] - v = setup_tensors["v"] - true_mid_out = torch.zeros_like(mid_out) - true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp) - new_q = q - new_k = k.to(q.dtype) - new_v = v.to(q.dtype) - - b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"] - req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :] - - from lightllm.common.basemodel.triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding_stage1 import ( - flash_decode_stage1 as gqa_flash_decode_stage1, - ) - - gqa_flash_decode_stage1( - q=new_q, - k=new_k, - v=new_v, - Req_to_tokens=req_to_tokens, - B_req_idx=setup_tensors["B_req_idx"], - B_Seqlen=b_seq_len, - max_len_in_batch=setup_tensors["max_len_in_batch"], - mid_out=true_mid_out, - mid_out_logsumexp=true_mid_out_logsumexp, - block_seq=setup_tensors["block_seq"], - ) - print(f"\nshared_seq_len={shared_seq_len}") - print(f"mid_out: {mid_out[0:4, 0, 0, 0]}") - print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}") - abs_diff = (mid_out - true_mid_out).abs() - max_diff = abs_diff.max() - max_diff_idx = abs_diff.argmax() - max_diff_idx_unraveled = torch.unravel_index(max_diff_idx, abs_diff.shape) - mid_out_value = mid_out[max_diff_idx_unraveled] - true_mid_out_value = true_mid_out[max_diff_idx_unraveled] - print(f"max abs diff: {max_diff}, mid_out value: {mid_out_value}, " f"true_mid_out value: {true_mid_out_value}") - - assert torch.allclose( - mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2 - ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}" - assert torch.allclose( - mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2 - ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}" From 89e7b9ed3c826af17e9f9aa537072b448fac1da2 Mon Sep 17 00:00:00 2001 From: sangchengmeng <101796078+SangChengC@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:19:18 +0800 Subject: [PATCH 37/43] check image tag and image num (#1176) Co-authored-by: sangchengmeng --- .../int8kv/int8kv_flash_decoding_diverse_stage1.py | 2 +- lightllm/models/internvl/model.py | 8 ++++++++ lightllm/models/qwen2_vl/model.py | 4 ++++ lightllm/models/qwen2_vl/vision_process.py | 2 ++ lightllm/models/qwen_vl/model.py | 3 ++- lightllm/models/tarsier2/model.py | 4 ++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py index 4dfaffef68..295ae66ab3 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse_stage1.py @@ -291,7 +291,7 @@ def flash_decode_stage1( assert k.stride() == v.stride() NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS - + assert k.stride() == v.stride() _fwd_kernel_flash_decode_diverse_stage1[grid]( Q=q, diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 6d264a4267..ccb76d3512 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -149,6 +149,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = 
len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) # audio @@ -174,6 +178,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("audio token error") except ValueError: break + if multimodal_params: + audio_cnt = len(multimodal_params.audios) + if audio_cnt != audio_id: + raise ValueError(audio_cnt == audio_id, f"invalid audio tag num: {audio_cnt} vs {audio_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index dd4181fbfb..237c4ad897 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -79,6 +79,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids) return input_ids diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 1c4f60794d..f2cd38ec8e 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -184,6 +184,8 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: return self._preprocess_bydevice(image, device="cpu") def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: + if image.mode != "RGB": + image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index d942d68497..0c6fa31f47 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -86,7 +86,8 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): input_ids.extend(origin_ids[end:]) if multimodal_params: image_cnt = len(multimodal_params.images) - assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id) + if image_cnt != image_id: + raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!") return input_ids diff --git a/lightllm/models/tarsier2/model.py b/lightllm/models/tarsier2/model.py index dad252b979..10a7f368c4 100644 --- a/lightllm/models/tarsier2/model.py +++ b/lightllm/models/tarsier2/model.py @@ -78,6 +78,10 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): raise ValueError("image token error") except ValueError: break + if multimodal_params: + image_cnt = len(multimodal_params.images) + if image_cnt != image_id: + raise ValueError(image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!") input_ids.extend(origin_ids[start_idx:]) return input_ids From 0777e2802bd24805f3b00204aa68ece30df54e26 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Tue, 20 Jan 2026 13:58:34 +0800 Subject: [PATCH 38/43] fix cpu kv cache offload async error (#1180) --- .../mode_backend/multi_level_kv_cache.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git 
a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index d4ba902999..b9d13f512b 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -26,11 +26,16 @@ def __init__(self, backend): self.init_sync_group = create_new_group_for_current_dp("nccl") dist.barrier(group=self.init_sync_group) + self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") + self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") + self.cpu_cache_handle_queue: Deque[TransTask] = deque() self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) # 一些算子模式需要同步计算和 cpu cache 的 load 和 offload 操作 - self.need_sync_compute_stream: bool = True + self.need_sync_compute_stream: bool = ( + "fa3" in self.args.llm_decode_att_backend or "fa3" in self.args.llm_prefill_att_backend + ) def wait(self): """ @@ -89,14 +94,18 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): cpu_kv_cache_scale = None gpu_kv_cache_scale = None + mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) + page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda( + non_blocking=True + ) # 将 cpu page 的内容拷贝到 gpu 页面中 load_cpu_kv_to_gpu( - gpu_mem_indexes=mem_indexes.cuda(non_blocking=True), + gpu_mem_indexes=mem_indexes_cuda, gpu_kv_cache=mem_manager.kv_buffer, gpu_kv_cache_scale=gpu_kv_cache_scale, cpu_kv_cache=cpu_kv_cache, cpu_kv_cache_scale=cpu_kv_cache_scale, - page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True), + page_indexes=page_indexes_cuda, tp_index=self.backend.rank_in_dp, tp_world_size=self.backend.dp_world_size, grid_num=grid_num, @@ -221,6 +230,12 @@ def _start_kv_cache_offload_task( page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True) page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True) + assert len(page_indexes) <= self.page_index_buffer.shape[0] + cuda_page_indexes = self.page_index_buffer[: len(page_indexes)] + cuda_page_readies = self.page_ready_buffer[: len(page_readies)] + cuda_page_indexes.copy_(page_indexes, non_blocking=True) + cuda_page_readies.copy_(page_readies, non_blocking=True) + move_token_num = item_size * self.args.cpu_cache_token_page_size assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num] @@ -248,8 +263,8 @@ def _start_kv_cache_offload_task( gpu_kv_cache_scale=gpu_kv_cache_scale, cpu_kv_cache=cpu_kv_cache, cpu_kv_cache_scale=cpu_kv_cache_scale, - page_indexes=page_indexes, - page_readies=page_readies, + page_indexes=cuda_page_indexes, + page_readies=cuda_page_readies, tp_index=self.backend.rank_in_dp, tp_world_size=self.backend.dp_world_size, grid_num=grid_num, From ac2ec929218755195a96bf4d446ba0cc3c0466f2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 22 Jan 2026 14:48:20 +0000 Subject: [PATCH 39/43] refactor fuse_moe --- lightllm/common/basemodel/basemodel.py | 6 +- .../layer_weights/meta_weights/__init__.py | 2 +- .../meta_weights/embedding_weight.py | 24 +- .../fused_moe/fused_moe_weight.py | 315 +++++++++++++++ .../fused_moe/fused_moe_weight_ep.py | 59 +-- .../fused_moe_weight_ep_redundancy.py | 4 +- .../fused_moe/fused_moe_weight_tp.py 
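The change above avoids handing freshly built CPU tensors to the asynchronous offload path: page indexes and ready flags are staged into persistent CUDA buffers with non-blocking copies, and only those device views reach the kernel. Reduced to its essentials, the staging pattern looks like this sketch (the class name and buffer capacity are illustrative, not taken from the patch).

import torch

class IndexStager:
    def __init__(self, capacity: int = 1 << 20):
        # persistent device buffers, allocated once so in-flight async copies never target freed memory
        self.idx_buf = torch.empty(capacity, dtype=torch.int32, device="cuda")
        self.ready_buf = torch.empty(capacity, dtype=torch.bool, device="cuda")

    def stage(self, idx_cpu: torch.Tensor, ready_cpu: torch.Tensor):
        # pinned host tensors make non_blocking=True a genuinely asynchronous H2D copy
        assert idx_cpu.is_pinned() and ready_cpu.is_pinned()
        n = idx_cpu.numel()
        idx_dev = self.idx_buf[:n]
        ready_dev = self.ready_buf[:n]
        idx_dev.copy_(idx_cpu, non_blocking=True)
        ready_dev.copy_(ready_cpu, non_blocking=True)
        return idx_dev, ready_dev  # pass these device views to the offload kernel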
| 364 ------------------ .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 10 +- .../meta_weights/fused_moe/impl/__init__.py | 14 + .../meta_weights/fused_moe/impl/base_impl.py | 55 +++ .../fused_moe/impl/deepgemm_impl.py | 336 ++++++++++++++++ .../fused_moe/impl/marlin_impl.py | 56 +++ .../fused_moe/impl/triton_impl.py | 138 +++++++ .../meta_weights/mm_weight/mm_weight.py | 40 +- .../layer_weights/meta_weights/norm_weight.py | 17 +- .../triton_kernel}/fused_moe/__init__.py | 0 .../fused_moe/deepep_scatter_gather.py | 0 .../fused_moe/grouped_fused_moe.py | 0 .../fused_moe/grouped_fused_moe_ep.py | 8 +- .../triton_kernel}/fused_moe/grouped_topk.py | 0 .../fused_moe/moe_kernel_configs.py | 0 .../fused_moe/moe_silu_and_mul.py | 0 .../fused_moe/moe_silu_and_mul_config.py | 0 .../moe_silu_and_mul_mix_quant_ep.py | 0 .../fused_moe/moe_sum_recude_config.py | 0 .../fused_moe/moe_sum_reduce.py | 0 .../triton_kernel}/fused_moe/softmax_topk.py | 0 .../triton_kernel}/fused_moe/topk_select.py | 6 +- .../common/quantization/quantize_method.py | 4 + lightllm/distributed/communication_op.py | 4 +- .../layer_infer/transformer_layer_infer.py | 4 +- .../layer_weights/transformer_layer_weight.py | 60 +-- .../layer_weights/transformer_layer_weight.py | 5 +- .../layer_infer/transformer_layer_infer.py | 2 +- .../layer_weights/transformer_layer_weight.py | 41 +- .../layer_infer/transformer_layer_infer.py | 5 +- .../layer_weights/transformer_layer_weight.py | 48 +-- .../transformers_layer_weight.py | 1 - .../layer_weights/transformer_layer_weight.py | 1 - lightllm/server/api_cli.py | 7 +- test/start_scripts/README.md | 2 +- test/start_scripts/multi_node_ep_node0.sh | 4 +- test/start_scripts/multi_node_ep_node1.sh | 4 +- .../multi_pd_master/pd_prefill.sh | 5 +- test/start_scripts/single_node_ep.sh | 5 +- .../single_pd_master/pd_decode.sh | 3 +- .../single_pd_master/pd_nixl_decode.sh | 3 +- .../single_pd_master/pd_nixl_prefill.sh | 3 +- .../single_pd_master/pd_prefill.sh | 5 +- 49 files changed, 1066 insertions(+), 604 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/__init__.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/deepep_scatter_gather.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/grouped_fused_moe.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/grouped_fused_moe_ep.py (96%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/grouped_topk.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/moe_kernel_configs.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/moe_silu_and_mul.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/moe_silu_and_mul_config.py (100%) rename lightllm/common/{ => 
basemodel/triton_kernel}/fused_moe/moe_silu_and_mul_mix_quant_ep.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/moe_sum_recude_config.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/moe_sum_reduce.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/softmax_topk.py (100%) rename lightllm/common/{ => basemodel/triton_kernel}/fused_moe/topk_select.py (96%) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 2dcf0c434a..e6405e4d74 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -637,8 +637,6 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod assert model_input0.mem_indexes.is_cuda assert model_input1.mem_indexes.is_cuda - input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids - infer_state0 = self._create_inferstate(model_input0, 0) init_req_to_token_indexes( req_to_token_indexs=self.req_manager.req_to_token_indexs, @@ -668,9 +666,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod prefill_mem_indexes_ready_event = torch.cuda.Event() prefill_mem_indexes_ready_event.record() - model_output0, model_output1 = self._overlap_tpsp_context_forward( - input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1 - ) + model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b67f271ca4..ab0e5b6040 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -9,5 +9,5 @@ from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight -from .fused_moe.fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe.fused_moe_weight_ep import FusedMoeWeightEP +from .fused_moe.fused_moe_weight import FusedMoeWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index 9737f41b29..d4e03d0a1f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -5,9 +5,6 @@ from .platform_op import PlatformAwareOp from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): @@ -28,7 +25,7 @@ def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch def _create_weight(self): tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) - self.load_cnt = 0 + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights: @@ -39,12 +36,11 @@ def load_hf_weights(self, weights: 
Dict[str, torch.Tensor]): assert ( loaded_vocab_size == self.vocab_size ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" - logger.info(f"loaded weight vocab_size: {self.vocab_size}") self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) - self.load_cnt += 1 + self.weight.load_ok = True def verify_load(self): - return self.load_cnt == 1 + return self.weight.load_ok def _native_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty @@ -114,12 +110,12 @@ def __init__( self._create_weight() def _create_weight(self): - self.load_cnt = 0 if self._embedding_weight is not None: self.weight = self._embedding_weight.weight return tp_vocab_size = self.tp_vocab_end_id - self.tp_vocab_start_id self.weight: torch.Tensor = torch.empty(tp_vocab_size, self.dim, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): # When set tile_embedding=True, no need to load - EmbeddingWeight already loaded it @@ -132,12 +128,11 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): assert ( loaded_vocab_size == self.vocab_size ), f"loaded weight vocab_size: {loaded_vocab_size} != expected vocab_size: {self.vocab_size}" - logger.info(f"loaded weight vocab_size: {self.vocab_size}") self.weight.copy_(t_weight[self.tp_vocab_start_id : self.tp_vocab_end_id, :].to(self.data_type_)) - self.load_cnt += 1 + self.weight.load_ok = True def verify_load(self): - return self.load_cnt == 1 or self._embedding_weight is not None + return self.weight.load_ok def _native_forward( self, input: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty @@ -181,7 +176,7 @@ def _create_weight(self): self.weight: torch.Tensor = torch.empty( self.max_position_embeddings, self.dim, dtype=self.data_type_, device=self.device_id_ ) - self.load_cnt = 0 + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name not in weights: @@ -191,12 +186,11 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): assert ( loaded_max_position_embeddings == self.max_position_embeddings ), f"max_position_embeddings: {loaded_max_position_embeddings} != expected: {self.max_position_embeddings}" - logger.info(f"loaded weight max_position_embeddings: {self.max_position_embeddings}") self.weight.copy_(t_weight.to(self.data_type_)) - self.load_cnt += 1 + self.weight.load_ok = True def verify_load(self): - return self.load_cnt == 1 + return self.weight.load_ok def _native_forward( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, _alloc_func=torch.empty diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py new file mode 100644 index 0000000000..8b01f46437 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -0,0 +1,315 @@ +import torch +import threading +from typing import Dict, Any, Optional, Tuple +from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( + get_row_slice_mixin, + get_col_slice_mixin, +) +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.impl import 
select_fuse_moe_impl +from lightllm.common.quantization.quantize_method import QuantizationMethod +from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args +from lightllm.utils.dist_utils import get_global_world_size, get_global_rank + + +class FusedMoeWeight(BaseWeightTpl): + def __init__( + self, + gate_proj_name: str, + down_proj_name: str, + up_proj_name: str, + e_score_correction_bias_name: str, + weight_prefix: str, + n_routed_experts: int, + hidden_size: int, + moe_intermediate_size: int, + data_type: torch.dtype, + quant_method: QuantizationMethod = None, + num_fused_shared_experts: int = 0, + layer_num: int = 0, + network_config: Dict[str, Any] = None, + ) -> None: + super().__init__(data_type=data_type) + self.w1_weight_name = gate_proj_name + self.w2_weight_name = down_proj_name + self.w3_weight_name = up_proj_name + self.e_score_correction_bias_name = e_score_correction_bias_name + self.weight_prefix = weight_prefix + self.layer_num_ = layer_num + self.global_rank_ = get_global_rank() + self.global_world_size = get_global_world_size() + self.hidden_size = hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.quant_method = quant_method + self.row_slicer = get_row_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) + self.col_slicer = get_col_slice_mixin( + self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ + ) + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." + self.enable_ep_moe = get_env_start_args().enable_ep_moe + self.n_routed_experts = n_routed_experts + self.num_fused_shared_experts = num_fused_shared_experts + self._init_config(network_config) + self._init_parallel_params() + self.fuse_moe_impl = select_fuse_moe_impl(self.quant_method, self.enable_ep_moe)( + n_routed_experts=self.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + redundancy_expert_num=self.redundancy_expert_num, + routed_scaling_factor=self.routed_scaling_factor, + quant_method=self.quant_method, + ) + self.lock = threading.Lock() + self._create_weight() + + def _init_config(self, network_config: Dict[str, Any]): + self.n_group = network_config.get("n_group", 0) + self.use_grouped_topk = self.n_group > 0 + self.norm_topk_prob = network_config["norm_topk_prob"] + self.topk_group = network_config.get("topk_group", 0) + self.num_experts_per_tok = network_config["num_experts_per_tok"] + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) + self.scoring_func = network_config.get("scoring_func", "softmax") + + def _init_parallel_params(self): + self.local_n_routed_experts = self.n_routed_experts + self.num_fused_shared_experts + self.start_expert_id = 0 + self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ + self.redundancy_expert_num = 0 + if self.enable_ep_moe: + assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" + self.redundancy_expert_num = get_redundancy_expert_num() + self.redundancy_expert_ids = get_redundancy_expert_ids(self.layer_num_) + self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num + self.start_expert_id = self.global_rank_ * self.n_routed_experts // self.global_world_size + self.split_inter_size = self.moe_intermediate_size + + def experts( + self, + input_tensor: torch.Tensor, + router_logits: 
torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ): + """Backward compatible method that routes to platform-specific implementation.""" + return self.fuse_moe_impl( + input_tensor=input_tensor, + router_logits=router_logits, + w13=self.w13, + w2=self.w2, + correction_bias=self.e_score_correction_bias, + scoring_func=self.scoring_func, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + is_prefill=is_prefill, + ) + + def low_latency_dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ): + assert self.enable_ep_moe, "low_latency_dispatch is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.low_latency_dispatch( + hidden_states=hidden_states, + router_logits=router_logits, + e_score_correction_bias=self.e_score_correction_bias, + use_grouped_topk=self.use_grouped_topk, + num_experts_per_tok=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + topk_group=self.topk_group, + n_group=self.n_group, + scoring_func=self.scoring_func, + ) + + def select_experts_and_quant_input( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ): + assert self.enable_ep_moe, "select_experts_and_quant_input is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.select_experts_and_quant_input( + hidden_states=hidden_states, + router_logits=router_logits, + e_score_correction_bias=self.e_score_correction_bias, + w13=self.w13, + use_grouped_topk=self.use_grouped_topk, + num_experts_per_tok=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + topk_group=self.topk_group, + n_group=self.n_group, + scoring_func=self.scoring_func, + ) + + def dispatch( + self, + qinput_tensor: Tuple[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_event: Optional[Any] = None, + ): + assert self.enable_ep_moe, "dispatch is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.dispatch( + qinput_tensor=qinput_tensor, + topk_idx=topk_idx, + topk_weights=topk_weights, + overlap_event=overlap_event, + ) + + def masked_group_gemm( + self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int + ): + assert self.enable_ep_moe, "masked_group_gemm is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.masked_group_gemm( + recv_x=recv_x, + w13=self.w13, + w2=self.w2, + masked_m=masked_m, + dtype=dtype, + expected_m=expected_m, + ) + + def prefilled_group_gemm( + self, + num_recv_tokens_per_expert_list, + recv_x: Tuple[torch.Tensor], + recv_topk_idx: torch.Tensor, + recv_topk_weights: torch.Tensor, + hidden_dtype=torch.bfloat16, + ): + assert self.enable_ep_moe, "prefilled_group_gemm is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.prefilled_group_gemm( + num_recv_tokens_per_expert_list=num_recv_tokens_per_expert_list, + recv_x=recv_x, + recv_topk_idx=recv_topk_idx, + recv_topk_weights=recv_topk_weights, + w13=self.w13, + w2=self.w2, + hidden_dtype=hidden_dtype, + ) + + def low_latency_combine( + self, + gemm_out_b: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Any, + ): + assert self.enable_ep_moe, "low_latency_combine is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.low_latency_combine( + gemm_out_b=gemm_out_b, + topk_idx=topk_idx, + 
topk_weights=topk_weights, + handle=handle, + ) + + def combine( + self, + gemm_out_b: torch.Tensor, + handle: Any, + overlap_event: Optional[Any] = None, + ): + assert self.enable_ep_moe, "combine is only supported when enable_ep_moe is True" + return self.fuse_moe_impl.combine( + gemm_out_b=gemm_out_b, + handle=handle, + overlap_event=overlap_event, + ) + + def load_hf_weights(self, weights): + # Load bias + if self.e_score_correction_bias_name in weights: + self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + + # Load each expert with TP slicing + for i_experts in range(self.start_expert_id, self.start_expert_id + self.local_n_routed_experts): + with self.lock: + self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) + if self.w13.weight_scale is not None: + with self.lock: + self._load_expert( + i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix + ) + if self.w13.weight_zero_point is not None: + with self.lock: + self._load_expert( + i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix + ) + + def verify_load(self): + # Count-based verification is skipped here: the expected load count differs between TP and EP expert layouts. + return True + + def _create_weight(self): + intermediate_size = self.split_inter_size + self.e_score_correction_bias = None + # Create e_score_correction_bias + if self.e_score_correction_bias_name: + self.e_score_correction_bias = torch.empty( + (self.n_routed_experts,), + dtype=self.data_type_, + device=f"cuda:{self.device_id_}", + ) + + self.w13: WeightPack = self.quant_method.create_weight( + out_dim=intermediate_size * 2, + in_dim=self.hidden_size, + dtype=self.data_type_, + device_id=self.device_id_, + num_experts=self.local_n_routed_experts, + ) + self.w2: WeightPack = self.quant_method.create_weight( + out_dim=self.hidden_size, + in_dim=intermediate_size, + dtype=self.data_type_, + device_id=self.device_id_, + num_experts=self.local_n_routed_experts, + ) + self.load_cnt = 0 + + def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, weight_pack, start_idx) + else: + self.quant_method.load_weight(weight, weight_pack, start_idx) + + def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): + w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" + w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" + w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" + intermediate_size = self.split_inter_size + load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) + local_expert_idx = expert_idx - self.start_expert_id + if w1_weight in weights: + load_func(slice_func(weights[w1_weight]), self.w13.get_expert(local_expert_idx), start_idx=0) + self.load_cnt += 1 + if w3_weight in weights: + load_func( + slice_func(weights[w3_weight]), self.w13.get_expert(local_expert_idx), start_idx=intermediate_size + ) + self.load_cnt += 1 + load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) + if w2_weight in weights: + load_func(slice_func(weights[w2_weight]), self.w2.get_expert(local_expert_idx), start_idx=0) + self.load_cnt += 1 + + def _get_load_and_slice_func(self, type: str, is_row: bool = True): + if is_row: + slicer = self.row_slicer + else: + slicer = self.col_slicer + if type == "weight": + return self._load_weight_func,
slicer._slice_weight + elif type == "weight_scale": + return getattr(self.quant_method, "load_weight_scale"), slicer._slice_weight_scale + elif type == "weight_zero_point": + return getattr(self.quant_method, "load_weight_zero_point"), slicer._slice_weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py index 342026de21..6659a98d4e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py @@ -4,14 +4,14 @@ from lightllm.utils.dist_utils import get_global_world_size, get_global_rank from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.fused_moe.grouped_fused_moe_ep import ( +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( fused_experts_impl, masked_group_gemm, _deepgemm_grouped_fp8_nt_contiguous, ) -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.distributed import dist_group_manager -from lightllm.common.fused_moe.topk_select import select_experts +from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num from lightllm.utils.envs_utils import get_env_start_args @@ -19,7 +19,7 @@ per_token_group_quant_fp8, tma_align_input_scale, ) -from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather +from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import Autotuner @@ -185,57 +185,6 @@ def _select_experts( ) return topk_weights, topk_ids - def _native_forward( - self, - input_tensor, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - is_prefill, - ): - """PyTorch native implementation for EP MoE forward pass.""" - topk_weights, topk_ids = self._select_experts( - input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - - # Native PyTorch implementation (less optimized but works on all platforms) - batch_size, hidden_size = input_tensor.shape - intermediate_size = w1.shape[1] // 2 - - output = torch.zeros_like(input_tensor) - - for i in range(batch_size): - expert_output = torch.zeros(hidden_size, dtype=input_tensor.dtype, device=input_tensor.device) - for j in range(top_k): - expert_idx = topk_ids[i, j].item() - weight = topk_weights[i, j] - - # Get local expert index (EP mode uses local expert indices) - local_expert_idx = expert_idx % self.ep_load_expert_num - - # Get expert weights - w1_expert = w1[local_expert_idx, :intermediate_size, :] # gate - w3_expert = w1[local_expert_idx, intermediate_size:, :] # up - w2_expert = w2[local_expert_idx] - - # Compute: SiLU(x @ w1.T) * (x @ w3.T) @ 
w2.T - x = input_tensor[i : i + 1] - gate = torch.nn.functional.silu(torch.mm(x, w1_expert.T)) - up = torch.mm(x, w3_expert.T) - hidden = gate * up - expert_out = torch.mm(hidden, w2_expert.T) - expert_output += weight * expert_out.squeeze(0) - - output[i] = expert_output - - return output - def _cuda_forward( self, input_tensor, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py index 933a94f78c..a31cd18803 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py @@ -1,6 +1,6 @@ import numpy as np import torch -from .fused_moe_weight_ep import FusedMoeWeightEP +from .fused_moe_weight import FusedMoeWeight from lightllm.utils.log_utils import init_logger from typing import Dict @@ -10,7 +10,7 @@ class FusedMoeWeightEPAutoRedundancy: def __init__( self, - ep_fused_moe_weight: FusedMoeWeightEP, + ep_fused_moe_weight: FusedMoeWeight, ) -> None: super().__init__() self._ep_w = ep_fused_moe_weight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py deleted file mode 100644 index c7892ab3ba..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_tp.py +++ /dev/null @@ -1,364 +0,0 @@ -import torch -from typing import Dict, Any, Union -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.quantization import Quantcfg -from lightllm.common.quantization.quantize_method import WeightPack -from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( - get_row_slice_mixin, - get_col_slice_mixin, -) -import threading - - -def create_tp_moe_wegiht_obj( - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, -) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]: - quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - if quant_method is not None and quant_method.method_name == "awq_marlin": - return FusedAWQMARLINMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - quant_cfg=quant_cfg, - ) - else: - return FusedMoeWeightTP( - gate_proj_name=gate_proj_name, - down_proj_name=down_proj_name, - up_proj_name=up_proj_name, - e_score_correction_bias_name=e_score_correction_bias_name, - weight_prefix=weight_prefix, - n_routed_experts=n_routed_experts, - num_fused_shared_experts=num_fused_shared_experts, - split_inter_size=split_inter_size, - data_type=data_type, - network_config=network_config, - layer_num=layer_num, - 
quant_cfg=quant_cfg, - ) - - -class FusedMoeWeightTP(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - num_fused_shared_experts: int, - split_inter_size: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg: Quantcfg = None, - ) -> None: - super().__init__() - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method.method_name != "none": - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - - self.e_score_correction_bias_name = e_score_correction_bias_name - self.weight_prefix = weight_prefix - assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." - self.n_routed_experts = n_routed_experts + num_fused_shared_experts - self.num_fused_shared_experts = num_fused_shared_experts - self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) - self.split_inter_size = split_inter_size - self.data_type_ = data_type - self.hidden_size = network_config.get("hidden_size") - self.e_score_correction_bias = None - self.scoring_func = network_config.get("scoring_func", "softmax") - self.row_slicer = get_row_slice_mixin( - self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ - ) - self.col_slicer = get_col_slice_mixin( - self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_ - ) - self.lock = threading.Lock() - self._create_weight() - - def _create_weight(self): - total_expert_num = self.n_routed_experts - intermediate_size = self.split_inter_size - - # Create e_score_correction_bias - if self.e_score_correction_bias_name is not None: - self.e_score_correction_bias = torch.empty( - (total_expert_num,), - dtype=self.data_type_, - device=f"cuda:{self.device_id_}", - ) - - self.w13: WeightPack = self.quant_method.create_weight( - out_dim=intermediate_size * 2, - in_dim=self.hidden_size, - dtype=self.data_type_, - device_id=self.device_id_, - num_experts=total_expert_num, - ) - self.w2: WeightPack = self.quant_method.create_weight( - out_dim=self.hidden_size, - in_dim=intermediate_size, - dtype=self.data_type_, - device_id=self.device_id_, - num_experts=total_expert_num, - ) - self.load_cnt = 0 - - def _select_experts( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """Select experts and return topk weights and ids.""" - from lightllm.common.fused_moe.topk_select import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=input_tensor, - router_logits=router_logits, - correction_bias=self.e_score_correction_bias, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - scoring_func=self.scoring_func, - ) - topk_weights.mul_(self.routed_scaling_factor) - if self.num_fused_shared_experts > 0: - pad_topk_ids = ( - torch.arange( - start=self.n_routed_experts - self.num_fused_shared_experts, - end=self.n_routed_experts, - step=1, - dtype=topk_ids.dtype, - device="cuda", - ) - .view(1, self.num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - pad_topk_weights = torch.full( - 
(topk_weights.shape[0], self.num_fused_shared_experts), - fill_value=1.0, - device="cuda", - dtype=topk_weights.dtype, - ) - - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) - return topk_weights, topk_ids - - def _native_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - topk_weights, topk_ids = self._select_experts( - input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ) - - w13, _ = self.w13.weight, self.w13.weight_scale - w2, _ = self.w2.weight, self.w2.weight_scale - - batch_size, hidden_size = input_tensor.shape - intermediate_size = w13.shape[1] // 2 - - output = torch.zeros_like(input_tensor) - - for i in range(batch_size): - expert_output = torch.zeros(hidden_size, dtype=input_tensor.dtype, device=input_tensor.device) - for j in range(top_k): - expert_idx = topk_ids[i, j].item() - weight = topk_weights[i, j] - - w1 = w13[expert_idx, :intermediate_size, :] # gate - w3 = w13[expert_idx, intermediate_size:, :] # up - w2_expert = w2[expert_idx] - - # Compute: SiLU(x @ w1.T) * (x @ w3.T) @ w2.T - x = input_tensor[i : i + 1] - gate = torch.nn.functional.silu(torch.mm(x, w1.T)) - up = torch.mm(x, w3.T) - hidden = gate * up - expert_out = torch.mm(hidden, w2_expert.T) - expert_output += weight * expert_out.squeeze(0) - - output[i] = expert_output - - input_tensor.copy_(output) - return - - def _cuda_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """CUDA optimized implementation of MoE forward pass.""" - topk_weights, topk_ids = self._select_experts( - input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ) - - w13, w13_scale = self.w13.weight, self.w13.weight_scale - w2, w2_scale = self.w2.weight, self.w2.weight_scale - use_fp8_w8a8 = self.quant_method.method_name != "none" - - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts - - fused_experts( - hidden_states=input_tensor, - w1=w13, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=use_fp8_w8a8, - w1_scale=w13_scale, - w2_scale=w2_scale, - ) - return - - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - """Backward compatible method that routes to platform-specific implementation.""" - return self._forward( - input_tensor=input_tensor, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - ) - - def load_hf_weights(self, weights): - # Load bias - if self.e_score_correction_bias_name in weights: - self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) - - # Load each expert with TP slicing - for i_experts in range(self.n_routed_experts): - with self.lock: - self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) - if self.w13.weight_scale is not None: - with self.lock: - self._load_expert( - i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix - ) - if self.w13.weight_zero_point is not None: - with self.lock: - self._load_expert( - i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix - ) - - def _load_weight_func(self, weight: torch.Tensor, weight_pack: 
WeightPack, start_idx: int = 0): - if self.quant_method.weight_need_quanted(weight): - self.quant_method.quantize(weight, weight_pack, start_idx) - else: - self.quant_method.load_weight(weight, weight_pack, start_idx) - - def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): - w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" - w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" - w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" - intermediate_size = self.split_inter_size - load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) - if w1_weight in weights: - load_func(slice_func(weights[w1_weight]), self.w13.get_expert(expert_idx), start_idx=0) - self.load_cnt += 1 - if w3_weight in weights: - load_func(slice_func(weights[w3_weight]), self.w13.get_expert(expert_idx), start_idx=intermediate_size) - self.load_cnt += 1 - load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) - if w2_weight in weights: - load_func(slice_func(weights[w2_weight]), self.w2.get_expert(expert_idx), start_idx=0) - self.load_cnt += 1 - - def verify_load(self): - return self.load_cnt == self.n_routed_experts * 3 * 2 - - def _get_load_and_slice_func(self, type: str, is_row: bool = True): - if is_row: - slicer = self.row_slicer - else: - slicer = self.col_slicer - if type == "weight": - return self._load_weight_func, slicer._slice_weight - elif type == "weight_scale": - return getattr(self.quant_method, "load_weight_scale"), slicer._slice_weight_scale - elif type == "weight_zero_point": - return getattr(self.quant_method, "load_weight_zero_point"), slicer._slice_weight_zero_point - - -class FusedAWQMARLINMoeWeightTP(FusedMoeWeightTP): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops - - assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" - from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, - ) - - self.workspace = marlin_make_workspace_new(self.w13.weight.device, 4) - - def _native_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """AWQ Marlin quantization requires CUDA, native forward not supported.""" - raise NotImplementedError("AWQ Marlin MoE requires CUDA platform, native forward not supported.") - - def _cuda_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """CUDA optimized implementation using AWQ Marlin kernels.""" - topk_weights, topk_ids = self._select_experts( - input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ) - - w1, w1_scale, w1_zero_point = self.w13.weight, self.w13.weight_scale, self.w13.weight_zero_point - w2, w2_scale, w2_zero_point = self.w2.weight, self.w2.weight_scale, self.w2.weight_zero_point - - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe - - fused_marlin_moe( - input_tensor, - w1, - w2, - None, - None, - w1_scale, - w2_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=self.quant_method.vllm_quant_type.id, - apply_router_weight_on_input=False, - global_num_experts=-1, - expert_map=None, - w1_zeros=w1_zero_point, - w2_zeros=w2_zero_point, - workspace=self.workspace, - inplace=True, - ) - - return diff --git 
a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 9821b5ad66..f3f153b0ab 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -1,9 +1,9 @@ import torch from typing import Dict, Any -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_tp import FusedMoeWeightTP -from lightllm.common.quantization import Quantcfg +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight from lightllm.utils.log_utils import init_logger +from lightllm.common.quantization.quantize_method import QuantizationMethod logger = init_logger(__name__) @@ -27,7 +27,7 @@ ] -class GPTOSSFusedMoeWeightTP(FusedMoeWeightTP): +class GPTOSSFusedMoeWeightTP(FusedMoeWeight): def __init__( self, gate_up_proj_name: str, # diff with FusedMoeWeightTP @@ -41,7 +41,7 @@ def __init__( network_config: Dict[str, Any], layer_num: int, world_size: int = 1, # diff with FusedMoeWeightTP - quant_cfg: Quantcfg = None, + quant_method: QuantizationMethod = None, ) -> None: super().__init__( gate_up_proj_name, @@ -55,7 +55,7 @@ def __init__( data_type, network_config, layer_num, - quant_cfg, + quant_method, ) self.hidden_size = network_config["hidden_size"] diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py new file mode 100644 index 0000000000..67bb90e4ef --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/__init__.py @@ -0,0 +1,14 @@ +from lightllm.common.quantization.quantize_method import QuantizationMethod +from .triton_impl import FuseMoeTriton +from .marlin_impl import FuseMoeMarlin +from .deepgemm_impl import FuseMoeDeepGEMM + + +def select_fuse_moe_impl(quant_method: QuantizationMethod, enable_ep_moe: bool): + if enable_ep_moe: + return FuseMoeDeepGEMM + + if quant_method.method_name == "awq_marlin": + return FuseMoeMarlin + else: + return FuseMoeTriton diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py new file mode 100644 index 0000000000..2f5d169eb1 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -0,0 +1,55 @@ +import torch +from abc import abstractmethod +from lightllm.common.quantization.quantize_method import ( + WeightPack, + QuantizationMethod, +) +from typing import Optional +from lightllm.utils.dist_utils import ( + get_global_rank, + get_global_world_size, +) + + +class FuseMoeBaseImpl: + def __init__( + self, + n_routed_experts: int, + num_fused_shared_experts: int, + redundancy_expert_num: int, + routed_scaling_factor: float, + quant_method: QuantizationMethod, + ): + self.n_routed_experts = n_routed_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.redundancy_expert_num = redundancy_expert_num + self.routed_scaling_factor = routed_scaling_factor + self.quant_method = quant_method + self.global_rank_ = get_global_rank() + self.global_world_size = get_global_world_size() + self.total_expert_num_contain_redundancy = ( + self.n_routed_experts + self.redundancy_expert_num * 
self.global_world_size + ) + self.workspace = self.create_workspace() + + @abstractmethod + def create_workspace(self): + pass + + @abstractmethod + def __call__( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ): + pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py new file mode 100644 index 0000000000..f00d572d9d --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -0,0 +1,336 @@ +import torch +from typing import Optional, Tuple, Any +from .triton_impl import FuseMoeTriton +from lightllm.distributed import dist_group_manager +from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.quantization.quantize_method import WeightPack +from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank +from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( + fused_experts_impl, + masked_group_gemm, + _deepgemm_grouped_fp8_nt_contiguous, +) +from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( + per_token_group_quant_fp8, + tma_align_input_scale, +) +from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair + + +class FuseMoeDeepGEMM(FuseMoeTriton): + def _select_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + correction_bias: Optional[torch.Tensor], + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + scoring_func: str, + ): + """Select experts and return topk weights and ids.""" + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.redundancy_expert_num > 0: + redundancy_topk_ids_repair( + topk_ids=topk_ids, + redundancy_expert_ids=self.redundancy_expert_ids_tensor, + ep_expert_num=self.ep_n_routed_experts, + global_rank=self.global_rank_, + expert_counter=self.routed_expert_counter_tensor, + enable_counter=self.auto_update_redundancy_expert, + ) + return topk_weights, topk_ids + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + ): + + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = self.quant_method.method_name != "none" + output = fused_experts_impl( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + 
topk_idx=topk_ids.to(torch.long), + num_experts=self.total_expert_num_contain_redundancy, # number of all experts contain redundancy + buffer=dist_group_manager.ep_buffer, + is_prefill=is_prefill, + use_fp8_w8a8=use_fp8_w8a8, + use_fp8_all2all=use_fp8_w8a8, + use_int8_w8a16=False, # default to False + w1_scale=w13_scale, + w2_scale=w2_scale, + previous_event=None, # for overlap + ) + return output + + def low_latency_dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + e_score_correction_bias: torch.Tensor, + use_grouped_topk: bool, + num_experts_per_tok: int, + norm_topk_prob: bool, + topk_group: int, + n_group: int, + scoring_func: str, + ): + topk_weights, topk_idx = self._select_experts( + input_tensor=hidden_states, + router_logits=router_logits, + correction_bias=e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=num_experts_per_tok, + renormalize=norm_topk_prob, + topk_group=topk_group, + num_expert_group=n_group, + scoring_func=scoring_func, + ) + + topk_idx = topk_idx.to(torch.long) + num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() + use_fp8_w8a8 = self.quant_method.method_name != "none" + recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + self.total_expert_num_contain_redundancy, + use_fp8=use_fp8_w8a8, + async_finish=False, + return_recv_hook=True, + ) + return recv_x, masked_m, topk_idx, topk_weights, handle, hook + + def select_experts_and_quant_input( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + e_score_correction_bias: torch.Tensor, + w13: WeightPack, + use_grouped_topk: bool, + num_experts_per_tok: int, + norm_topk_prob: bool, + topk_group: int, + n_group: int, + scoring_func: str, + ): + topk_weights, topk_idx = self._select_experts( + input_tensor=hidden_states, + router_logits=router_logits, + correction_bias=e_score_correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=num_experts_per_tok, + renormalize=norm_topk_prob, + topk_group=topk_group, + num_expert_group=n_group, + scoring_func=scoring_func, + ) + w13_weight, w13_scale = w13.weight, w13.weight_scale + block_size_k = 0 + if w13_weight.ndim == 3: + block_size_k = w13_weight.shape[2] // w13_scale.shape[2] + assert block_size_k == 128, "block_size_k must be 128" + qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w13_weight.dtype) + return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + + def dispatch( + self, + qinput_tensor: Tuple[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_event: Optional[Any] = None, + ): + buffer = dist_group_manager.ep_buffer + # get_dispatch_layout + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event, + ) = buffer.get_dispatch_layout( + topk_idx, + self.total_expert_num_contain_redundancy, + previous_event=overlap_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( + qinput_tensor, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + 
expert_alignment=128, + ) + + def hook(): + event.current_stream_wait() + + return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + + def masked_group_gemm( + self, + recv_x: Tuple[torch.Tensor], + w13: WeightPack, + w2: WeightPack, + masked_m: torch.Tensor, + dtype: torch.dtype, + expected_m: int, + ): + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + return masked_group_gemm( + recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m + ) + + def prefilled_group_gemm( + self, + num_recv_tokens_per_expert_list, + recv_x: Tuple[torch.Tensor], + recv_topk_idx: torch.Tensor, + recv_topk_weights: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + hidden_dtype=torch.bfloat16, + ): + device = recv_x[0].device + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + _, K = recv_x[0].shape + _, N, _ = w13_weight.shape + block_size = self.quant_method.block_size + # scatter + all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums. + # gather_out shape [recive_num_tokens, hidden] + gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype) + if all_tokens > 0: + input_tensor = [ + torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype), + torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32), + ] + # when m_indices is filled ok. + # m_indices show token use which expert, example, [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..] + # the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1] + # ... + m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32) + # output_index shape [recive_num_tokens, topk_num] + # output_index use to show the token index in input_tensor + output_index = torch.empty_like(recv_topk_idx) + + num_recv_tokens_per_expert = torch.tensor( + num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + ).cuda(non_blocking=True) + + expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) + + ep_scatter( + recv_x[0], + recv_x[1], + recv_topk_idx, + num_recv_tokens_per_expert, + expert_start_loc, + input_tensor[0], + input_tensor[1], + m_indices, + output_index, + ) + input_tensor[1] = tma_align_input_scale(input_tensor[1]) + # groupgemm (contiguous layout) + gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) + + _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w13_weight, w13_scale), gemm_out_a, m_indices) + + # silu_and_mul_fwd + qaunt + # TODO fused kernel + silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype) + + silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) + qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( + silu_out, block_size, dtype=w13_weight.dtype, column_major_scales=True, scale_tma_aligned=True + ) + + # groupgemm (contiguous layout) + gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) + + _deepgemm_grouped_fp8_nt_contiguous( + (qsilu_out, qsilu_out_scale), (w2_weight, w2_scale), gemm_out_b, m_indices + ) + # gather and local reduce + ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out) + else: + ######################################## warning ################################################## + # here is used to match autotune feature, make moe model run same triton kernel in different rank. 
+ # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel. + if Autotuner.is_autotune_warmup(): + _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype) + _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + _gemm_out_a, _silu_out = None, None + + return gather_out + + def low_latency_combine( + self, + gemm_out_b: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Any, + ): + combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine( + gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True + ) + return combined_x, hook + + def combine( + self, + gemm_out_b: torch.Tensor, + handle: Any, + overlap_event: Optional[Any] = None, + ): + # normal combine + combined_x, _, event = dist_group_manager.ep_buffer.combine( + gemm_out_b, + handle, + topk_weights=None, + async_finish=True, + previous_event=overlap_event, + allocate_on_comm_stream=True, + ) + + def hook(): + event.current_stream_wait() + + return combined_x, hook diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py new file mode 100644 index 0000000000..bdccbc0eef --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -0,0 +1,56 @@ +import torch +from .triton_impl import FuseMoeTriton +from lightllm.common.quantization.quantize_method import ( + WeightPack, +) +from typing import Optional + + +class FuseMoeMarlin(FuseMoeTriton): + def create_workspace(self): + from lightllm.utils.vllm_utils import HAS_VLLM + + assert HAS_VLLM, "moe awq marlin quantization requires kernels of vllm" + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, + ) + + return marlin_make_workspace_new(torch.device("cuda"), 4) + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: Optional[bool] = None, + ): + + w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point + w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe + + fused_marlin_moe( + input_tensor, + w1_weight, + w2_weight, + None, + None, + w1_scale, + w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_method.vllm_quant_type.id, + apply_router_weight_on_input=False, + global_num_experts=-1, + expert_map=None, + w1_zeros=w1_zero_point, + w2_zeros=w2_zero_point, + workspace=self.workspace, + inplace=True, + ) + return input_tensor diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py new file mode 100644 index 0000000000..9965246a23 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -0,0 +1,138 @@ +import torch +from typing import Optional +from lightllm.common.quantization.no_quant import WeightPack +from lightllm.common.quantization.quantize_method import QuantizationMethod +from .base_impl import FuseMoeBaseImpl + + +class 
FuseMoeTriton(FuseMoeBaseImpl): + def __init__( + self, + n_routed_experts: int, + num_fused_shared_experts: int, + redundancy_expert_num: int, + routed_scaling_factor: float, + quant_method: QuantizationMethod, + ): + super().__init__( + n_routed_experts, num_fused_shared_experts, redundancy_expert_num, routed_scaling_factor, quant_method + ) + + def create_workspace(self): + return None + + def _select_experts( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + correction_bias: Optional[torch.Tensor], + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + scoring_func: str, + ): + """Select experts and return topk weights and ids.""" + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts + + topk_weights, topk_ids = select_experts( + hidden_states=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + topk_weights.mul_(self.routed_scaling_factor) + if self.num_fused_shared_experts > 0: + pad_topk_ids = ( + torch.arange( + start=self.n_routed_experts, + end=self.n_routed_experts + self.num_fused_shared_experts, + step=1, + dtype=topk_ids.dtype, + device="cuda", + ) + .view(1, self.num_fused_shared_experts) + .repeat(topk_ids.shape[0], 1) + ) + pad_topk_weights = torch.full( + (topk_weights.shape[0], self.num_fused_shared_experts), + fill_value=1.0, + device="cuda", + dtype=topk_weights.dtype, + ) + + topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_topk_weights], dim=1) + return topk_weights, topk_ids + + def _fused_experts( + self, + input_tensor: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + router_logits: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ): + w13_weight, w13_scale = w13.weight, w13.weight_scale + w2_weight, w2_scale = w2.weight, w2.weight_scale + use_fp8_w8a8 = w13_weight.dtype == torch.float8_e4m3fn + + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe import fused_experts + + fused_experts( + hidden_states=input_tensor, + w1=w13_weight, + w2=w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=use_fp8_w8a8, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + return input_tensor + + def __call__( + self, + input_tensor: torch.Tensor, + router_logits: torch.Tensor, + w13: WeightPack, + w2: WeightPack, + correction_bias: Optional[torch.Tensor], + scoring_func: str, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: int, + num_expert_group: int, + is_prefill: Optional[bool] = None, + ): + topk_weights, topk_ids = self._select_experts( + input_tensor=input_tensor, + router_logits=router_logits, + correction_bias=correction_bias, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + scoring_func=scoring_func, + ) + output = self._fused_experts( + input_tensor=input_tensor, + w13=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + router_logits=router_logits, + is_prefill=is_prefill, + ) + return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 3ba4d3e592..1133e4d6ab 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -65,7 +65,6 @@ def mm( def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod]): if quant_method is None: - self.quanted_weight_names = None self.weight_zero_point_names = None self.weight_scale_names = None return @@ -86,9 +85,7 @@ def gen_weight_quant_param_names(self, quant_method: Optional[QuantizationMethod quanted_weight_names.append(weight_name) if len(quanted_weight_names) != 0: - self.quanted_weight_names = quanted_weight_names - else: - self.quanted_weight_names = None + self.weight_names = quanted_weight_names if len(weight_scale_names) != 0: self.weight_scale_names = weight_scale_names @@ -106,10 +103,6 @@ def load_hf_weights(self, weights): for sub_child_index, param_name in enumerate(self.weight_names): self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - if self.quanted_weight_names is not None: - for sub_child_index, param_name in enumerate(self.quanted_weight_names): - self._load_weight(param_name=param_name, weights=weights, sub_child_index=sub_child_index) - if self.bias_names is not None: for sub_child_index, param_name in enumerate(self.bias_names): self._load_bias(param_name=param_name, weights=weights, sub_child_index=sub_child_index) @@ -124,10 +117,11 @@ def _create_weight(self): self.bias = None if self.bias_names is not None: self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id()) + self.bias._load_ok = [False] * len(self.bias_names) self.mm_param: WeightPack = self.quant_method.create_weight( in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id() ) - self.load_cnt = 0 + self.mm_param.initialize_load_status(len(self.weight_names)) return # 执行顺序 @@ -139,9 +133,13 @@ def _load_weight( start_idx = self.cusum_out_dims[sub_child_index] if self.quant_method.weight_need_quanted(weight): self.quant_method.quantize(weight, self.mm_param, offset=start_idx) + # weight_scale and zero_point will be computed during online quantization. + # so we set them to True here. 
+ self.mm_param.load_ok[sub_child_index][1] = True + self.mm_param.load_ok[sub_child_index][2] = True else: self.quant_method.load_weight(weight, self.mm_param, start_idx) - self.load_cnt += 1 + self.mm_param.load_ok[sub_child_index][0] = True return def _load_bias( @@ -151,7 +149,8 @@ def _load_bias( bias = self.param_slicer._slice_bias(weights[param_name]) start_idx = self.cusum_out_dims[sub_child_index] end_idx = start_idx + bias.shape[0] - self.mm_param.bias[start_idx:end_idx].copy_(bias) + self.bias[start_idx:end_idx].copy_(bias) + self.bias._load_ok[sub_child_index] = True return def _load_weight_scale( @@ -161,7 +160,7 @@ def _load_weight_scale( weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) start_idx = self.cusum_out_dims[sub_child_index] self.quant_method.load_weight_scale(weight_scale, self.mm_param, start_idx) - self.load_cnt += 1 + self.mm_param.load_ok[sub_child_index][1] = True return def _load_weight_zero_point( @@ -171,14 +170,15 @@ def _load_weight_zero_point( weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) start_idx = self.cusum_out_dims[sub_child_index] self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param, start_idx) - self.load_cnt += 1 + self.mm_param.load_ok[sub_child_index][2] = True return def verify_load(self): - if self.quant_method.method_name != "none": - return self.load_cnt == len(self.weight_names) * 2 - else: - return self.load_cnt == len(self.weight_names) + mm_param_load_ok = all(all(load_ok_list) for load_ok_list in self.mm_param.load_ok) + bias_load_ok = True if self.bias is None else all(self.bias._load_ok) + if not (mm_param_load_ok and bias_load_ok): + logger.warning(f"mm_param_load_ok: {self.mm_param.load_ok}, bias_load_ok: {self.bias}") + return mm_param_load_ok and bias_load_ok def _get_tp_dim(self, dim: int) -> int: assert ( @@ -218,7 +218,7 @@ def __init__( def _create_weight(self): self.weight = torch.empty(self.dim0, self.dim1, self.dim2, dtype=self.data_type_).cuda(get_current_device_id()) - self.load_cnt = 0 + self.weight._load_ok = False return def load_hf_weights(self, weights: Dict[str, torch.Tensor]): @@ -226,11 +226,11 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if weight_name in weights: weight = self.param_slicer._slice_weight(weights[weight_name]) self.weight.copy_(weight) - self.load_cnt += 1 + self.weight._load_ok = True return def verify_load(self): - return self.load_cnt == len(self.weight_names) + return self.weight._load_ok def bmm( self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 1a8f59723b..d4717386b3 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -19,15 +19,15 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - self.load_cnt = 0 + self.weight.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) - self.load_cnt += 1 + self.weight.load_ok = True def verify_load(self): - return self.load_cnt == 1 + return self.weight.load_ok def _native_forward( self, 
input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty @@ -84,18 +84,19 @@ def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name def _create_weight(self): self.weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) self.bias: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) - self.load_cnt = 0 + self.weight.load_ok = False + self.bias.load_ok = False def load_hf_weights(self, weights: Dict[str, torch.Tensor]): if self.weight_name in weights: self.weight.copy_(weights[self.weight_name]) - self.load_cnt += 1 + self.weight.load_ok = True if self.bias_name in weights: self.bias.copy_(weights[self.bias_name]) - self.load_cnt += 1 + self.bias.load_ok = True def verify_load(self): - return self.load_cnt == 2 + return self.weight.load_ok and self.bias.load_ok def _native_forward( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty @@ -175,7 +176,7 @@ def load_hf_weights(self, weights): self.weight[:, end - start].copy_(t_weight[start:end].to(self.data_type_)) # the padding part is zero self.weight[:, end:].zero_() - self.load_cnt += 1 + self.weight.load_ok = True class NoTpGEMMANormWeight(RMSNormWeight): diff --git a/lightllm/common/fused_moe/__init__.py b/lightllm/common/basemodel/triton_kernel/fused_moe/__init__.py similarity index 100% rename from lightllm/common/fused_moe/__init__.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/__init__.py diff --git a/lightllm/common/fused_moe/deepep_scatter_gather.py b/lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py similarity index 100% rename from lightllm/common/fused_moe/deepep_scatter_gather.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/deepep_scatter_gather.py diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py similarity index 100% rename from lightllm/common/fused_moe/grouped_fused_moe.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py similarity index 96% rename from lightllm/common/fused_moe/grouped_fused_moe_ep.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2a577890b2..2c6d013bd5 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -6,13 +6,15 @@ from typing import Any, Callable, Dict, Optional, Tuple import torch.distributed as dist from lightllm.utils.log_utils import init_logger -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import ( + silu_and_mul_masked_post_quant_fwd, +) from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( per_token_group_quant_fp8, tma_align_input_scale, ) -from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather +from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from 
lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank from lightllm.common.triton_utils.autotuner import Autotuner import numpy as np diff --git a/lightllm/common/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py similarity index 100% rename from lightllm/common/fused_moe/grouped_topk.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py diff --git a/lightllm/common/fused_moe/moe_kernel_configs.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py similarity index 100% rename from lightllm/common/fused_moe/moe_kernel_configs.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul_config.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py diff --git a/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py similarity index 100% rename from lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py diff --git a/lightllm/common/fused_moe/moe_sum_recude_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py similarity index 100% rename from lightllm/common/fused_moe/moe_sum_recude_config.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py diff --git a/lightllm/common/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py similarity index 100% rename from lightllm/common/fused_moe/moe_sum_reduce.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py diff --git a/lightllm/common/fused_moe/softmax_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/softmax_topk.py similarity index 100% rename from lightllm/common/fused_moe/softmax_topk.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/softmax_topk.py diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py similarity index 96% rename from lightllm/common/fused_moe/topk_select.py rename to lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 5206800efc..72c3a381ed 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -22,7 +22,7 @@ from lightllm.utils.sgl_utils import sgl_ops from lightllm.utils.light_utils import light_ops from typing import Callable, List, Optional, Tuple -from lightllm.common.fused_moe.softmax_topk import softmax_topk +from lightllm.common.basemodel.triton_kernel.fused_moe.softmax_topk import softmax_topk from lightllm.common.triton_utils.autotuner import Autotuner use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"] @@ -177,8 +177,8 @@ def select_experts( scoring_func: str = "softmax", custom_routing_function: 
Optional[Callable] = None, ): - from lightllm.common.fused_moe.topk_select import fused_topk - from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk + from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import fused_topk + from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_topk import triton_grouped_topk # DeekSeekv2 uses grouped_top_k if use_grouped_topk: diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 77e59465ee..4350307f1f 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -18,6 +18,10 @@ def get_expert(self, expert_idx: int): weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + def initialize_load_status(self, weight_num: int): + initial_loaded_status = [False, self.weight_scale is None, self.weight_zero_point is None] + self.load_ok = [initial_loaded_status.copy() for _ in range(weight_num)] + class QuantizationMethod(ABC): def __init__(self): diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index d5c96f821a..d606d757c1 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -136,9 +136,9 @@ def get_group(self, group_index: int) -> CustomProcessGroup: return self.groups[group_index] def new_deepep_group(self, n_routed_experts, hidden_size): - moe_mode = os.getenv("MOE_MODE", "TP") + enable_ep_moe = get_env_start_args().enable_ep_moe num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() - if moe_mode == "TP": + if not enable_ep_moe: self.ep_buffer = None return assert HAS_DEEPEP, "deep_ep is required for expert parallelism" diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 801ab6aba2..e1e435cce4 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -63,8 +63,8 @@ def _bind_func(self): def _bind_ffn(self): if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": + enable_ep_moe = get_env_start_args().enable_ep_moe + if enable_ep_moe: self._ffn = partial(Deepseek2TransformerLayerInfer._moe_ffn_edp, self) self._tpsp_ffn = self._tpsp_ffn_ep else: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 1e8d572e15..783e70e64b 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -9,8 +9,7 @@ ROWBMMWeight, COLMMWeight, RMSNormWeight, - FusedMoeWeightEP, - create_tp_moe_wegiht_obj, + FusedMoeWeight, ) from ..triton_kernel.weight_dequant import weight_dequant @@ -39,9 +38,8 @@ def _parse_config(self): self.kv_lora_rank = self.network_config_["kv_lora_rank"] self.num_fused_shared_experts = 0 if get_env_start_args().enable_fused_shared_experts and self.is_moe: - # MOE_MODE 处于 TP 模式下才能使能 enable_fused_shared_experts - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode == "TP" + # enable_fused_shared_experts can only work with tensor parallelism + assert not get_env_start_args().enable_ep_moe, "enable_fused_shared_experts can 
only work with tp mode." self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) self.n_embed = self.network_config_["hidden_size"] self.n_inter = self.network_config_["intermediate_size"] @@ -97,7 +95,6 @@ def load_hf_weights(self, weights): weight_scale_suffix = None if self.quant_cfg.quantized_weight: weight_scale_suffix = kv_b_quant_method.weight_scale_suffix - if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] # for deepseek_v3, the bmm operator is not quantized @@ -187,9 +184,9 @@ def _init_qkvo(self): ) def _load_mlp(self, mlp_prefix, is_shared_experts=False): - moe_mode = os.getenv("MOE_MODE", "TP") + enable_ep_moe = get_env_start_args().enable_ep_moe mlp_inter = self.moe_inter if is_shared_experts else self.n_inter - if self.is_moe and moe_mode == "EP": + if self.is_moe and enable_ep_moe: self.gate_up_proj = ROWMMWeight( in_dim=self.n_embed, out_dims=[mlp_inter, mlp_inter], @@ -243,38 +240,21 @@ def _init_moe(self): # == 0 时,说明不存在融合共享专家,共享专家单独加载和进行推理。 if self.num_fused_shared_experts == 0: self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", is_shared_experts=True) - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode in ["EP", "TP"] - if moe_mode == "TP": - self.experts = create_tp_moe_wegiht_obj( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name=self.e_score_correction_bias_name, - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - num_fused_shared_experts=self.num_fused_shared_experts, - split_inter_size=moe_intermediate_size // self.tp_world_size_, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - elif moe_mode == "EP": - self.experts = FusedMoeWeightEP( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name=self.e_score_correction_bias_name, - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - else: - raise ValueError(f"Unsupported moe mode: {moe_mode}") + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name=self.e_score_correction_bias_name, + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.n_embed, + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + num_fused_shared_experts=self.num_fused_shared_experts, + layer_num=self.layer_num_, + network_config=self.network_config_, + ) def _init_ffn(self): self._load_mlp(f"model.layers.{self.layer_num_}.mlp") diff --git a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py index 0e7f4c8732..e6d58c3b2f 100644 --- a/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gpt_oss/layer_weights/transformer_layer_weight.py @@ -9,6 +9,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import TpAttSinkWeight from 
lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -25,10 +26,10 @@ def __init__( return def _init_moe(self): - moe_mode = os.getenv("MOE_MODE", "TP") + enable_ep_moe = get_env_start_args().enable_ep_moe moe_intermediate_size = self.network_config_["intermediate_size"] n_routed_experts = self.network_config_["num_local_experts"] - assert moe_mode in ["TP"], "For now, GPT-OSS type model only support MOE TP mode." + assert not enable_ep_moe, "For now, GPT-OSS type model only support MOE TP mode." self.moe_gate = ROWMMWeight( in_dim=self.n_embed, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 820c5efa0d..dc6f10be59 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -4,7 +4,7 @@ from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.basemodel import TransformerLayerInferTpl from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor diff --git a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py index fa20a63f9c..51c62fd4cb 100644 --- a/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/mixtral/layer_weights/transformer_layer_weight.py @@ -2,7 +2,8 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -31,7 +32,6 @@ def _init_ffn(self): def _init_moe(self): inter_size = self.network_config_["intermediate_size"] - split_inter_size = inter_size // self.tp_world_size_ self.moe_gate = ROWMMWeight( in_dim=self.n_embed, @@ -43,25 +43,18 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, # no tensor parallelism ) - - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}" - - if moe_mode == "TP": - self.experts = create_tp_moe_wegiht_obj( - gate_proj_name="w1", - down_proj_name="w2", - up_proj_name="w3", - e_score_correction_bias_name="", - weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts", - n_routed_experts=self.n_routed_experts, - split_inter_size=split_inter_size, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - num_fused_shared_experts=0, - hidden_size=self.network_config_.get("hidden_size"), - ) - else: - raise ValueError(f"Unsupported moe mode: {moe_mode}") + assert 
not get_env_start_args().enable_ep_moe, "Mixtral only supports tp mode." + self.experts = FusedMoeWeight( + gate_proj_name="w1", + down_proj_name="w2", + up_proj_name="w3", + e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.n_embed, + moe_intermediate_size=inter_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 52f9289eb1..71b16cb34b 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -14,6 +14,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_global_world_size from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor +from lightllm.utils.envs_utils import get_env_start_args logger = init_logger(__name__) @@ -41,8 +42,8 @@ def _bind_func(self): def _bind_ffn(self): if self.is_moe: - moe_mode = os.environ.get("MOE_MODE", "TP") - if moe_mode == "EP": + enable_ep_moe = get_env_start_args().enable_ep_moe + if enable_ep_moe: self._ffn = partial(Qwen3MOETransformerLayerInfer._moe_ffn_edp, self) self._tpsp_ffn = self._tpsp_ffn_ep else: diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 54cf7f02db..a889609d7f 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -1,6 +1,6 @@ import os from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeight class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): @@ -52,35 +52,17 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) - moe_mode = os.getenv("MOE_MODE", "TP") - assert moe_mode in ["EP", "TP"] - if moe_mode == "TP": - self.experts = create_tp_moe_wegiht_obj( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name="", - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - split_inter_size=moe_intermediate_size // self.tp_world_size_, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - num_fused_shared_experts=0, - ) - elif moe_mode == "EP": - self.experts = FusedMoeWeightEP( - gate_proj_name="gate_proj", - down_proj_name="down_proj", - up_proj_name="up_proj", - e_score_correction_bias_name="", - weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", - n_routed_experts=self.n_routed_experts, - data_type=self.data_type_, - network_config=self.network_config_, - layer_num=self.layer_num_, - quant_cfg=self.quant_cfg, - ) - else: - raise ValueError(f"Unsupported moe mode: {moe_mode}") + self.experts = FusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + 
e_score_correction_bias_name="", + weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", + n_routed_experts=self.n_routed_experts, + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=moe_intermediate_size, + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + ) diff --git a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py index 48ddf52089..83c05ba264 100644 --- a/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py +++ b/lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py @@ -1,6 +1,5 @@ import os from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 8bcbe3358e..54ad367867 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -1,5 +1,4 @@ import os -from turtle import TPen import torch import math import numpy as np diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index f30d4cdf77..82ff934554 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -516,6 +516,11 @@ def make_argument_parser() -> argparse.ArgumentParser: " Therefore, it is recommended to set this parameter according to actual needs." ), ) + parser.add_argument( + "--enable_ep_moe", + action="store_true", + help="""Whether to enable ep moe for deepseekv3 model.""", + ) parser.add_argument( "--ep_redundancy_expert_config_path", type=str, @@ -530,7 +535,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_fused_shared_experts", action="store_true", - help="""Whether to enable fused shared experts for deepseekv3 model. only work when MOE_MODE=TP """, + help="""Whether to enable fused shared experts for deepseekv3 model. 
Only works with tensor parallelism.""", ) parser.add_argument( "--mtp_mode", diff --git a/test/start_scripts/README.md b/test/start_scripts/README.md index e00af27139..8ed44a2753 100644 --- a/test/start_scripts/README.md +++ b/test/start_scripts/README.md @@ -99,7 +99,6 @@ sh multi_pd_master/pd_decode.sh ### Environment Variables - `LOADWORKER`: Model loading thread count, recommended 8-18 -- `MOE_MODE`: Expert parallelism mode, set to EP to enable expert parallelism - `DISABLE_KV_TRANS_USE_P2P`: Disable P2P communication optimization to transfer kv data - `CUDA_VISIBLE_DEVICES`: Specify GPU devices to use @@ -108,6 +107,7 @@ sh multi_pd_master/pd_decode.sh - `--model_dir`: Model file path - `--tp`: Tensor parallelism degree - `--dp`: Data parallelism degree +- `--enable_ep_moe`: Enable expert parallelism - `--nnodes`: Total number of nodes - `--node_rank`: Current node rank - `--nccl_host`: NCCL communication host address diff --git a/test/start_scripts/multi_node_ep_node0.sh b/test/start_scripts/multi_node_ep_node0.sh index 68f80b39d5..2e46f70ac2 100644 --- a/test/start_scripts/multi_node_ep_node0.sh +++ b/test/start_scripts/multi_node_ep_node0.sh @@ -2,7 +2,7 @@ # nccl_host: the ip of the nccl host # sh multi_node_ep_node0.sh export nccl_host=$1 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ +LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -10,7 +10,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ ---nccl_port 2732 +--nccl_port 2732 --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap #--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/multi_node_ep_node1.sh b/test/start_scripts/multi_node_ep_node1.sh index 10aee85285..a8ce02a8b6 100644 --- a/test/start_scripts/multi_node_ep_node1.sh +++ b/test/start_scripts/multi_node_ep_node1.sh @@ -2,7 +2,7 @@ # nccl_host: the ip of the nccl host # sh multi_node_ep_node1.sh export nccl_host=$1 -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ +LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -10,7 +10,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ ---nccl_port 2732 +--nccl_port 2732 --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap #--enable_decode_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/multi_pd_master/pd_prefill.sh b/test/start_scripts/multi_pd_master/pd_prefill.sh index eaa343ef62..4087c16557 100644 --- a/test/start_scripts/multi_pd_master/pd_prefill.sh +++ b/test/start_scripts/multi_pd_master/pd_prefill.sh @@ -5,7 +5,7 @@ export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -16,6 +16,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ ---config_server_port 60088 +--config_server_port 
60088 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ No newline at end of file diff --git a/test/start_scripts/single_node_ep.sh b/test/start_scripts/single_node_ep.sh index 7406d94628..05798a2049 100644 --- a/test/start_scripts/single_node_ep.sh +++ b/test/start_scripts/single_node_ep.sh @@ -1,9 +1,10 @@ # H200 single node deepseek R1 dpep mode -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ +LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ ---llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 +--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ #--enable_decode_microbatch_overlap \ diff --git a/test/start_scripts/single_pd_master/pd_decode.sh b/test/start_scripts/single_pd_master/pd_decode.sh index 36804dd11e..5471f1d84d 100644 --- a/test/start_scripts/single_pd_master/pd_decode.sh +++ b/test/start_scripts/single_pd_master/pd_decode.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -14,6 +14,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --port 8121 \ --nccl_port 12322 \ --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ +--enable_ep_moe \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_decode.sh b/test/start_scripts/single_pd_master/pd_nixl_decode.sh index 5fb34a973e..9354661d3e 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_decode.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_decode.sh @@ -10,7 +10,7 @@ export UCX_LOG_LEVEL=info export UCX_TLS=rc,cuda,gdr_copy nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_decode" \ --tp 8 \ @@ -19,6 +19,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --port 8121 \ --nccl_port 12322 \ --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ +--enable_ep_moe \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 # if you want to enable microbatch overlap, you can uncomment the following lines diff --git a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh index 5a37df0b1d..ea3188be39 100644 --- a/test/start_scripts/single_pd_master/pd_nixl_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_nixl_prefill.sh @@ -11,7 +11,7 @@ export UCX_TLS=rc,cuda,gdr_copy export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "nixl_prefill" \ --tp 8 \ @@ -20,6 +20,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --port 8019 \ --nccl_port 2732 \ --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ +--enable_ep_moe \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ --pd_master_port 60011 
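Note on the start-script changes above: every script drops the `MOE_MODE=EP` environment variable and instead passes `--enable_ep_moe` to the server. As a quick, hedged reference only (the model path, port, and tp/dp values are the placeholders used in `test/start_scripts/single_node_ep.sh`, not recommendations), a single-node EP launch now looks like this:

```bash
# Sketch only: expert parallelism is now requested with the --enable_ep_moe flag
# rather than the MOE_MODE=EP environment variable. Values below mirror the
# placeholders in test/start_scripts/single_node_ep.sh.
LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
--model_dir /path/DeepSeek-R1 \
--tp 8 \
--dp 8 \
--llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \
--enable_ep_moe
```

Omitting `--enable_ep_moe` keeps the default tensor-parallel MoE path, since the flag is a plain `store_true` switch.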
diff --git a/test/start_scripts/single_pd_master/pd_prefill.sh b/test/start_scripts/single_pd_master/pd_prefill.sh index b94a1f8ccd..131b3754e6 100644 --- a/test/start_scripts/single_pd_master/pd_prefill.sh +++ b/test/start_scripts/single_pd_master/pd_prefill.sh @@ -5,7 +5,7 @@ export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d -MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ +LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -16,6 +16,7 @@ MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ --llm_prefill_att_backend fa3 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ ---pd_master_port 60011 +--pd_master_port 60011 \ +--enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap \ No newline at end of file From e8463f105a1d01acf98ad38a408f6a187d9a1647 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 22 Jan 2026 17:33:55 +0000 Subject: [PATCH 40/43] redunancy_expert(draft) --- .../source/tutorial/deepseek_deployment.rst | 35 +++--- .../source/tutorial/deepseek_deployment.rst | 35 +++--- ...ight_ep_redundancy.py => ep_redundancy.py} | 0 .../fused_moe/fused_moe_weight.py | 110 +++++++++++++----- .../fused_moe/gpt_oss_fused_moe_weight_tp.py | 74 ++---------- .../meta_weights/fused_moe/impl/base_impl.py | 19 ++- .../fused_moe/impl/triton_impl.py | 14 ++- .../mode_backend/redundancy_expert_manager.py | 6 +- 8 files changed, 156 insertions(+), 137 deletions(-) rename lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/{fused_moe_weight_ep_redundancy.py => ep_redundancy.py} (100%) diff --git a/docs/CN/source/tutorial/deepseek_deployment.rst b/docs/CN/source/tutorial/deepseek_deployment.rst index 5d57b137c6..720fe10e0b 100644 --- a/docs/CN/source/tutorial/deepseek_deployment.rst +++ b/docs/CN/source/tutorial/deepseek_deployment.rst @@ -53,15 +53,16 @@ LightLLM 支持以下几种部署模式: .. 
code-block:: bash # H200 单机 DeepSeek-R1 DP + EP 模式 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ --llm_prefill_att_backend fa3 \ - --llm_decode_att_backend fa3 + --llm_decode_att_backend fa3 \ + --enable_ep_moe **参数说明:** -- `MOE_MODE=EP`: 设置专家并行模式 +- `--enable_ep_moe`: 设置专家并行模式 - `--tp 8`: 张量并行度 - `--dp 8`: 数据并行度,通常设置为与 tp 相同的值 - `--llm_prefill_att_backend fa3`: 启用 Flash Attention 3.0 @@ -131,7 +132,7 @@ LightLLM 支持以下几种部署模式: # H200 多机 DeepSeek-R1 EP 模式 Node 0 # 使用方法: sh multi_node_ep_node0.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -140,7 +141,7 @@ LightLLM 支持以下几种部署模式: --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Node 1 启动命令:** @@ -149,7 +150,7 @@ LightLLM 支持以下几种部署模式: # H200 多机 DeepSeek-R1 EP 模式 Node 1 # 使用方法: sh multi_node_ep_node1.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -158,7 +159,7 @@ LightLLM 支持以下几种部署模式: --nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **可选优化参数:** - `--enable_prefill_microbatch_overlap`: 启用预填充微批次重叠 @@ -195,7 +196,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -207,7 +208,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_prefill_microbatch_overlap @@ -220,7 +222,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -232,7 +234,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_decode_microbatch_overlap @@ -289,7 +292,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -301,7 +304,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_prefill_microbatch_overlap @@ -309,7 +313,7 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m 
lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --host $host \ @@ -320,7 +324,8 @@ PD (Prefill-Decode) 分离模式将预填充和解码阶段分离部署,可以 --llm_prefill_att_backend fa3 \ --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # 如果需要启用微批次重叠,可以取消注释以下行 #--enable_decode_microbatch_overlap diff --git a/docs/EN/source/tutorial/deepseek_deployment.rst b/docs/EN/source/tutorial/deepseek_deployment.rst index accdbc462b..c908490afb 100755 --- a/docs/EN/source/tutorial/deepseek_deployment.rst +++ b/docs/EN/source/tutorial/deepseek_deployment.rst @@ -52,15 +52,16 @@ Suitable for expert parallelism deployment of MoE models like DeepSeek-V2/V3. .. code-block:: bash # H200 Single node DeepSeek-R1 DP + EP Mode - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 8 \ --dp 8 \ --llm_prefill_att_backend fa3 \ - --llm_decode_att_backend fa3 + --llm_decode_att_backend fa3 \ + --enable_ep_moe **Parameter Description:** -- `MOE_MODE=EP`: Set expert parallelism mode +- `--enable_ep_moe`: Set expert parallelism mode - `--tp 8`: Tensor parallelism - `--dp 8`: Data parallelism, usually set to the same value as tp @@ -128,7 +129,7 @@ Suitable for deploying MoE models across multiple nodes. # H200 Multi-node DeepSeek-R1 EP Mode Node 0 # Usage: sh multi_node_ep_node0.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -137,7 +138,7 @@ Suitable for deploying MoE models across multiple nodes. --nnodes 2 \ --node_rank 0 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Node 1 Launch Command:** @@ -146,7 +147,7 @@ Suitable for deploying MoE models across multiple nodes. # H200 Multi-node DeepSeek-R1 EP Mode Node 1 # Usage: sh multi_node_ep_node1.sh export nccl_host=$1 - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ + LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \ --model_dir /path/DeepSeek-R1 \ --tp 16 \ --dp 16 \ @@ -155,7 +156,7 @@ Suitable for deploying MoE models across multiple nodes. 
--nnodes 2 \ --node_rank 1 \ --nccl_host $nccl_host \ - --nccl_port 2732 + --nccl_port 2732 --enable_ep_moe **Optional Optimization Parameters:** - `--enable_prefill_microbatch_overlap`: Enable prefill microbatch overlap @@ -192,7 +193,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --tp 8 \ @@ -203,7 +204,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --llm_prefill_att_backend fa3 \ --llm_decode_att_backend fa3 \ --disable_cudagraph \ - --pd_master_ip $pd_master_ip + --pd_master_ip $pd_master_ip \ + --enable_ep_moe **Step 3: Launch Decode Service** @@ -214,7 +216,7 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for export host=$1 export pd_master_ip=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --tp 8 \ @@ -226,7 +228,8 @@ PD (Prefill-Decode) disaggregation mode separates prefill and decode stages for --llm_decode_att_backend fa3 \ --disable_cudagraph \ --pd_master_ip $pd_master_ip \ - --pd_master_port 60011 + --pd_master_port 60011 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_decode_microbatch_overlap @@ -283,7 +286,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "prefill" \ --host $host \ @@ -295,7 +298,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --llm_decode_att_backend fa3 \ --disable_cudagraph \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_prefill_microbatch_overlap @@ -303,7 +307,7 @@ Supports multiple PD Master nodes, providing better load balancing and high avai export host=$1 export config_server_host=$2 nvidia-cuda-mps-control -d - MOE_MODE=EP LOADWORKER=18 python -m lightllm.server.api_server \ + LOADWORKER=18 python -m lightllm.server.api_server \ --model_dir /path/DeepSeek-R1 \ --run_mode "decode" \ --host $host \ @@ -314,7 +318,8 @@ Supports multiple PD Master nodes, providing better load balancing and high avai --llm_prefill_att_backend fa3 \ --llm_decode_att_backend fa3 \ --config_server_host $config_server_host \ - --config_server_port 60088 + --config_server_port 60088 \ + --enable_ep_moe # if you want to enable microbatch overlap, you can uncomment the following lines #--enable_decode_microbatch_overlap diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py similarity index 100% rename from lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep_redundancy.py rename to lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/ep_redundancy.py diff --git 
a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8b01f46437..ced1e9267b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -11,6 +11,9 @@ from lightllm.common.quantization.quantize_method import QuantizationMethod from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num, get_env_start_args from lightllm.utils.dist_utils import get_global_world_size, get_global_rank +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class FusedMoeWeight(BaseWeightTpl): @@ -53,13 +56,17 @@ def __init__( self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts self._init_config(network_config) + self._init_redundancy_expert_params() self._init_parallel_params() self.fuse_moe_impl = select_fuse_moe_impl(self.quant_method, self.enable_ep_moe)( n_routed_experts=self.n_routed_experts, num_fused_shared_experts=self.num_fused_shared_experts, - redundancy_expert_num=self.redundancy_expert_num, routed_scaling_factor=self.routed_scaling_factor, quant_method=self.quant_method, + redundancy_expert_num=self.redundancy_expert_num, + redundancy_expert_ids_tensor=self.redundancy_expert_ids_tensor, + routed_expert_counter_tensor=self.routed_expert_counter_tensor, + auto_update_redundancy_expert=self.auto_update_redundancy_expert, ) self.lock = threading.Lock() self._create_weight() @@ -73,18 +80,40 @@ def _init_config(self, network_config: Dict[str, Any]): self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.scoring_func = network_config.get("scoring_func", "softmax") + def _init_redundancy_expert_params(self): + self.redundancy_expert_num = get_redundancy_expert_num() + self.redundancy_expert_ids = get_redundancy_expert_ids(self.layer_num_) + self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert + self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") + self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") + def _init_parallel_params(self): self.local_n_routed_experts = self.n_routed_experts + self.num_fused_shared_experts - self.start_expert_id = 0 self.split_inter_size = self.moe_intermediate_size // self.tp_world_size_ - self.redundancy_expert_num = 0 if self.enable_ep_moe: assert self.num_fused_shared_experts == 0, "num_fused_shared_experts must be 0 when enable_ep_moe" - self.redundancy_expert_num = get_redundancy_expert_num() - self.redundancy_expert_ids = get_redundancy_expert_ids(self.layer_num_) + logger.info( + f"global_rank {self.global_rank_} layerindex {self.layer_num_} " + f"redundancy_expertids: {self.redundancy_expert_ids}" + ) self.local_n_routed_experts = self.n_routed_experts // self.global_world_size + self.redundancy_expert_num - self.start_expert_id = self.global_rank_ * self.n_routed_experts // self.global_world_size self.split_inter_size = self.moe_intermediate_size + n_experts_per_rank = self.n_routed_experts // self.global_world_size + start_expert_id = self.global_rank_ * n_experts_per_rank + self.local_expert_ids = ( + list(range(start_expert_id, start_expert_id + n_experts_per_rank)) + self.redundancy_expert_ids + ) + self.expert_idx_to_local_idx = { + 
expert_idx: expert_idx - start_expert_id for expert_idx in self.local_expert_ids[:n_experts_per_rank] + } + self.redundancy_expert_idx_to_local_idx = { + redundancy_expert_idx: n_experts_per_rank + i + for (i, redundancy_expert_idx) in enumerate(self.redundancy_expert_ids) + } + else: + self.local_expert_ids = list(range(self.n_routed_experts + self.num_fused_shared_experts)) + self.expert_idx_to_local_idx = {expert_idx: i for (i, expert_idx) in enumerate(self.local_expert_ids)} + self.rexpert_idx_to_local_idx = {} def experts( self, @@ -229,25 +258,12 @@ def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) - - # Load each expert with TP slicing - for i_experts in range(self.start_expert_id, self.start_expert_id + self.local_n_routed_experts): - with self.lock: - self._load_expert(i_experts, weights, type="weight", suffix=self.quant_method.weight_suffix) - if self.w13.weight_scale is not None: - with self.lock: - self._load_expert( - i_experts, weights, type="weight_scale", suffix=self.quant_method.weight_scale_suffix - ) - if self.w13.weight_zero_point is not None: - with self.lock: - self._load_expert( - i_experts, weights, type="weight_zero_point", suffix=self.quant_method.weight_zero_point_suffix - ) + self._load_weight(self.expert_idx_to_local_idx, weights) + if self.redundancy_expert_num > 0: + self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) def verify_load(self): return True - return self.load_cnt == self.n_routed_experts * 3 * 2 def _create_weight(self): intermediate_size = self.split_inter_size @@ -276,31 +292,61 @@ def _create_weight(self): ) self.load_cnt = 0 - def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): - if self.quant_method.weight_need_quanted(weight): - self.quant_method.quantize(weight, weight_pack, start_idx) - else: - self.quant_method.load_weight(weight, weight_pack, start_idx) + def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - def _load_expert(self, expert_idx, weights, type: str, suffix: str = "weight"): + # Load each expert with TP slicing + for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): + with self.lock: + self._load_expert( + expert_idx, local_expert_idx, weights, type="weight", suffix=self.quant_method.weight_suffix + ) + if self.w13.weight_scale is not None: + with self.lock: + self._load_expert( + expert_idx, + local_expert_idx, + weights, + type="weight_scale", + suffix=self.quant_method.weight_scale_suffix, + ) + if self.w13.weight_zero_point is not None: + with self.lock: + self._load_expert( + expert_idx, + local_expert_idx, + weights, + type="weight_zero_point", + suffix=self.quant_method.weight_zero_point_suffix, + ) + + def _load_expert( + self, + expert_idx: int, + local_expert_idx: int, + weights: Dict[str, torch.Tensor], + type: str, + suffix: str = "weight", + ): w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{suffix}" intermediate_size = self.split_inter_size load_func, slice_func = self._get_load_and_slice_func(type, is_row=True) - local_expert_idx = expert_idx - self.start_expert_id if w1_weight in weights: load_func(slice_func(weights[w1_weight]), self.w13.get_expert(local_expert_idx), start_idx=0) - 
self.load_cnt += 1 if w3_weight in weights: load_func( slice_func(weights[w3_weight]), self.w13.get_expert(local_expert_idx), start_idx=intermediate_size ) - self.load_cnt += 1 load_func, slice_func = self._get_load_and_slice_func(type, is_row=False) if w2_weight in weights: load_func(slice_func(weights[w2_weight]), self.w2.get_expert(local_expert_idx), start_idx=0) - self.load_cnt += 1 + + def _load_weight_func(self, weight: torch.Tensor, weight_pack: WeightPack, start_idx: int = 0): + if self.quant_method.weight_need_quanted(weight): + self.quant_method.quantize(weight, weight_pack, start_idx) + else: + self.quant_method.load_weight(weight, weight_pack, start_idx) def _get_load_and_slice_func(self, type: str, is_row: bool = True): if is_row: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index f3f153b0ab..e7748b1dfd 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -1,9 +1,12 @@ +import os import torch -from typing import Dict, Any +import threading +from typing import Optional, Tuple, List, Dict, Any from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id +from lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger -from lightllm.common.quantization.quantize_method import QuantizationMethod logger = init_logger(__name__) @@ -41,7 +44,7 @@ def __init__( network_config: Dict[str, Any], layer_num: int, world_size: int = 1, # diff with FusedMoeWeightTP - quant_method: QuantizationMethod = None, + quant_cfg: Quantcfg = None, ) -> None: super().__init__( gate_up_proj_name, @@ -55,7 +58,7 @@ def __init__( data_type, network_config, layer_num, - quant_method, + quant_cfg, ) self.hidden_size = network_config["hidden_size"] @@ -118,56 +121,7 @@ def router(self, router_logits, top_k): router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_top_value, router_indices - def _native_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """PyTorch native implementation for GPT-OSS MoE forward pass.""" - topk_weights, topk_ids = self.router(router_logits, top_k) - - w1, w1_scale = self.w1 - w2, w2_scale = self.w2 - - batch_size, hidden_size = input_tensor.shape - - output = torch.zeros_like(input_tensor) - input_bf16 = input_tensor.to(torch.bfloat16) - - for i in range(batch_size): - expert_output = torch.zeros(hidden_size, dtype=torch.bfloat16, device=input_tensor.device) - for j in range(top_k): - expert_idx = topk_ids[i, j].item() - weight = topk_weights[i, j] - - w1_expert = w1[expert_idx] - w2_expert = w2[expert_idx] - - x = input_bf16[i : i + 1] - hidden = torch.mm(x, w1_expert.T) # [1, intermediate_size * 2] - if self.w1_bias is not None: - hidden = hidden + self.w1_bias[expert_idx : expert_idx + 1] - - gate = hidden[:, 0::2] - up = hidden[:, 1::2] - - gate = torch.clamp(gate * self.alpha, -self.limit, self.limit) - gate = torch.nn.functional.sigmoid(gate) - hidden = gate * up - - expert_out = torch.mm(hidden, w2_expert.T) - if self.w2_bias is not None: - expert_out = expert_out + 
self.w2_bias[expert_idx : expert_idx + 1] / self.tp_world_size_ - - expert_output += weight * expert_out.squeeze(0) - - output[i] = expert_output - - input_tensor.copy_(output.to(input_tensor.dtype)) - return output - - def _cuda_forward( - self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group - ): - """CUDA optimized implementation for GPT-OSS MoE forward pass.""" + def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): topk_weights, topk_ids = self.router(router_logits, top_k) w1, w1_scale = self.w1 @@ -194,18 +148,6 @@ def _cuda_forward( ) return output_tensor - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): - """Backward compatible method that routes to platform-specific implementation.""" - return self._forward( - input_tensor=input_tensor, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - ) - def _convert_moe_packed_tensors( self, blocks, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 2f5d169eb1..c56cd4da31 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -16,20 +16,31 @@ def __init__( self, n_routed_experts: int, num_fused_shared_experts: int, - redundancy_expert_num: int, routed_scaling_factor: float, quant_method: QuantizationMethod, + redundancy_expert_num: int, + redundancy_expert_ids_tensor: torch.Tensor, + routed_expert_counter_tensor: torch.Tensor, + auto_update_redundancy_expert: bool, ): self.n_routed_experts = n_routed_experts self.num_fused_shared_experts = num_fused_shared_experts - self.redundancy_expert_num = redundancy_expert_num self.routed_scaling_factor = routed_scaling_factor self.quant_method = quant_method self.global_rank_ = get_global_rank() - self.global_world_size = get_global_world_size() + self.global_world_size_ = get_global_world_size() + self.ep_n_routed_experts = self.n_routed_experts // self.global_world_size_ self.total_expert_num_contain_redundancy = ( - self.n_routed_experts + self.redundancy_expert_num * self.global_world_size + self.n_routed_experts + redundancy_expert_num * self.global_world_size_ ) + + # redundancy expert related + self.redundancy_expert_num = redundancy_expert_num + self.redundancy_expert_ids_tensor = redundancy_expert_ids_tensor + self.routed_expert_counter_tensor = routed_expert_counter_tensor + self.auto_update_redundancy_expert = auto_update_redundancy_expert + + # workspace for kernel optimization self.workspace = self.create_workspace() @abstractmethod diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 9965246a23..8bcdb4bf90 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -10,12 +10,22 @@ def __init__( self, n_routed_experts: int, num_fused_shared_experts: int, - redundancy_expert_num: int, routed_scaling_factor: float, quant_method: QuantizationMethod, + redundancy_expert_num: int, + 
redundancy_expert_ids_tensor: torch.Tensor, + routed_expert_counter_tensor: torch.Tensor, + auto_update_redundancy_expert: bool, ): super().__init__( - n_routed_experts, num_fused_shared_experts, redundancy_expert_num, routed_scaling_factor, quant_method + n_routed_experts=n_routed_experts, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + quant_method=quant_method, + redundancy_expert_num=redundancy_expert_num, + redundancy_expert_ids_tensor=redundancy_expert_ids_tensor, + routed_expert_counter_tensor=routed_expert_counter_tensor, + auto_update_redundancy_expert=auto_update_redundancy_expert, ) def create_workspace(self): diff --git a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py index e3a71379d2..596eca4f24 100644 --- a/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/redundancy_expert_manager.py @@ -8,10 +8,10 @@ import json from typing import List from lightllm.common.basemodel.basemodel import TpPartBaseModel -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep_redundancy import ( +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.ep_redundancy import ( FusedMoeWeightEPAutoRedundancy, ) -from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight_ep import FusedMoeWeightEP +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight from lightllm.utils.envs_utils import get_env_start_args, get_redundancy_expert_update_interval from lightllm.utils.envs_utils import get_redundancy_expert_update_max_load_count from lightllm.utils.envs_utils import get_redundancy_expert_num @@ -28,7 +28,7 @@ def __init__(self, model: TpPartBaseModel): self.model = model self.ep_fused_moeweights: List[FusedMoeWeightEPAutoRedundancy] = [] for layer in self.model.trans_layers_weight: - ep_weights = self._find_members_of_class(layer, FusedMoeWeightEP) + ep_weights = self._find_members_of_class(layer, FusedMoeWeight) assert len(ep_weights) <= 1 self.ep_fused_moeweights.extend([FusedMoeWeightEPAutoRedundancy(e) for e in ep_weights]) From ec0fe0f5e7dff07c4d3fdb1fd13f147064fe2274 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Thu, 22 Jan 2026 17:43:03 +0000 Subject: [PATCH 41/43] remove weight_ep --- .../layer_weights/meta_weights/__init__.py | 1 - .../fused_moe/fused_moe_weight_ep.py | 692 ------------------ 2 files changed, 693 deletions(-) delete mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index ab0e5b6040..8e884012d5 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -9,5 +9,4 @@ from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight -from .fused_moe.fused_moe_weight_ep import FusedMoeWeightEP from .fused_moe.fused_moe_weight import FusedMoeWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py 
b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py deleted file mode 100644 index 6659a98d4e..0000000000 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight_ep.py +++ /dev/null @@ -1,692 +0,0 @@ -import torch -import threading -from typing import Optional, Tuple, List, Dict, Any -from lightllm.utils.dist_utils import get_global_world_size, get_global_rank -from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl -from lightllm.common.basemodel.layer_weights.meta_weights.platform_op import PlatformAwareOp -from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import ( - fused_experts_impl, - masked_group_gemm, - _deepgemm_grouped_fp8_nt_contiguous, -) -from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd -from lightllm.distributed import dist_group_manager -from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts -from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank -from lightllm.utils.envs_utils import get_redundancy_expert_ids, get_redundancy_expert_num -from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( - per_token_group_quant_fp8, - tma_align_input_scale, -) -from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather -from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair -from lightllm.utils.log_utils import init_logger -from lightllm.common.triton_utils.autotuner import Autotuner -from lightllm.common.quantization.quantize_method import WeightPack - - -logger = init_logger(__name__) - - -class FusedMoeWeightEP(BaseWeightTpl, PlatformAwareOp): - def __init__( - self, - gate_proj_name: str, - down_proj_name: str, - up_proj_name: str, - e_score_correction_bias_name: str, - weight_prefix: str, - n_routed_experts: int, - data_type: torch.dtype, - network_config: Dict[str, Any], - layer_num: int, - quant_cfg=None, - hidden_size: Optional[int] = None, - ) -> None: - super().__init__() - - self.layer_num = layer_num - self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") - self.quantized_weight = quant_cfg.quantized_weight - if self.quant_method is not None: - self.weight_scale_suffix = self.quant_method.weight_scale_suffix - self.quant_method.is_moe = True - block_size = 1 - if hasattr(self.quant_method, "block_size"): - block_size = self.quant_method.block_size - self.block_size = block_size - - self.weight_prefix = weight_prefix - self.w1_weight_name = gate_proj_name - self.w2_weight_name = down_proj_name - self.w3_weight_name = up_proj_name - self.e_score_correction_bias_name = e_score_correction_bias_name - self.n_routed_experts = n_routed_experts - self.hidden_size = hidden_size - - global_world_size = get_global_world_size() - self.global_rank_ = get_global_rank() - self.redundancy_expert_num = get_redundancy_expert_num() - self.redundancy_expert_ids = get_redundancy_expert_ids(layer_num) - logger.info( - f"global_rank {self.global_rank_} layerindex {layer_num} redundancy_expertids: {self.redundancy_expert_ids}" - ) - self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") - self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") - 
-        self.total_expert_num_contain_redundancy = (
-            self.n_routed_experts + self.redundancy_expert_num * global_world_size
-        )
-        assert self.n_routed_experts % global_world_size == 0
-        self.ep_n_routed_experts = self.n_routed_experts // global_world_size
-        ep_load_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
-        self.ep_load_expert_num = ep_load_expert_num
-        self.experts_up_projs = [None] * ep_load_expert_num
-        self.experts_gate_projs = [None] * ep_load_expert_num
-        self.experts_up_proj_scales = [None] * ep_load_expert_num
-        self.experts_gate_proj_scales = [None] * ep_load_expert_num
-        self.e_score_correction_bias = None
-        self.w2_list = [None] * ep_load_expert_num
-        self.w2_scale_list = [None] * ep_load_expert_num
-        self.scoring_func = network_config.get("scoring_func", "softmax")
-        self.w1 = [None, None]  # weight, weight_scale
-        self.w2 = [None, None]  # weight, weight_scale
-        self.use_fp8_w8a8 = self.quant_method is not None
-        network_config["n_group"] = network_config.get("n_group", 0)
-        self.num_experts_per_tok = network_config["num_experts_per_tok"]
-        self.use_grouped_topk = network_config["n_group"] > 0
-        self.norm_topk_prob = network_config["norm_topk_prob"]
-        self.n_group = network_config["n_group"]
-        network_config["topk_group"] = network_config.get("topk_group", 0)
-        self.topk_group = network_config["topk_group"]
-        network_config["routed_scaling_factor"] = network_config.get("routed_scaling_factor", 1.0)
-        self.routed_scaling_factor = network_config["routed_scaling_factor"]
-
-        self.lock = threading.Lock()
-        # init buffer
-
-        # auto update redundancy expert vars
-        self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert
-
-        # Pre-allocate memory if hidden_size is provided
-        if self.hidden_size is not None:
-            self._create_weight()
-
-    def _create_weight(self):
-        """Pre-allocate GPU memory for fused MoE weights"""
-        if self.hidden_size is None:
-            return
-
-        total_expert_num = self.ep_load_expert_num
-        # We need to determine intermediate size from network config or use a default
-        # This will be updated when first weight is loaded if needed
-        intermediate_size = getattr(self, "intermediate_size", None)
-        if intermediate_size is None:
-            # Default fallback - this will be corrected during load
-            intermediate_size = self.hidden_size * 4
-
-        if not self.quantized_weight and self.quant_method is not None:
-            # Quantized weights
-            w1_pack = self.quant_method.create_weight(
-                total_expert_num * intermediate_size * 2,
-                self.hidden_size,
-                dtype=self.data_type_,
-                device_id=self.device_id_,
-            )
-            self.w1[0] = w1_pack.weight.view(total_expert_num, intermediate_size * 2, self.hidden_size)
-            self.w1[1] = w1_pack.weight_scale.view(total_expert_num, intermediate_size * 2, self.hidden_size)
-
-            w2_pack = self.quant_method.create_weight(
-                total_expert_num * self.hidden_size,
-                intermediate_size,
-                dtype=self.data_type_,
-                device_id=self.device_id_,
-            )
-            self.w2[0] = w2_pack.weight.view(total_expert_num, self.hidden_size, intermediate_size)
-            self.w2[1] = w2_pack.weight_scale.view(total_expert_num, self.hidden_size, intermediate_size)
-        else:
-            # Regular weights
-            self.w1[0] = torch.empty(
-                (total_expert_num, intermediate_size * 2, self.hidden_size),
-                dtype=self.data_type_,
-                device=f"cuda:{self.device_id_}",
-            )
-            self.w2[0] = torch.empty(
-                (total_expert_num, self.hidden_size, intermediate_size),
-                dtype=self.data_type_,
-                device=f"cuda:{self.device_id_}",
-            )
-
-    def _select_experts(
-        self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group
-    ):
-        """Select experts and return topk weights and ids."""
-        topk_weights, topk_ids = select_experts(
-            hidden_states=input_tensor,
-            router_logits=router_logits,
-            correction_bias=self.e_score_correction_bias,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            scoring_func=self.scoring_func,
-        )
-        topk_weights.mul_(self.routed_scaling_factor)
-
-        if self.redundancy_expert_num > 0:
-            redundancy_topk_ids_repair(
-                topk_ids=topk_ids,
-                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
-                ep_expert_num=self.ep_n_routed_experts,
-                global_rank=self.global_rank_,
-                expert_counter=self.routed_expert_counter_tensor,
-                enable_counter=self.auto_update_redundancy_expert,
-            )
-        return topk_weights, topk_ids
-
-    def _cuda_forward(
-        self,
-        input_tensor,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        is_prefill,
-    ):
-        """CUDA optimized implementation for EP MoE forward pass."""
-        topk_weights, topk_ids = self._select_experts(
-            input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group
-        )
-
-        w1, w1_scale = self.w1
-        w2, w2_scale = self.w2
-        return fused_experts_impl(
-            hidden_states=input_tensor,
-            w1=w1,
-            w2=w2,
-            topk_weights=topk_weights,
-            topk_idx=topk_ids.to(torch.long),
-            num_experts=self.total_expert_num_contain_redundancy,  # number of all experts contain redundancy
-            buffer=dist_group_manager.ep_buffer,
-            is_prefill=is_prefill,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            use_fp8_all2all=self.use_fp8_w8a8,
-            use_int8_w8a16=False,  # default to False
-            w1_scale=w1_scale,
-            w2_scale=w2_scale,
-            previous_event=None,  # for overlap
-        )
-
-    def experts(
-        self,
-        input_tensor,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        is_prefill,
-    ):
-        """Backward compatible method that routes to platform-specific implementation."""
-        return self._forward(
-            input_tensor=input_tensor,
-            router_logits=router_logits,
-            top_k=top_k,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            is_prefill=is_prefill,
-        )
-
-    def low_latency_dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-    ):
-
-        topk_weights, topk_idx = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            correction_bias=self.e_score_correction_bias,
-            use_grouped_topk=self.use_grouped_topk,
-            top_k=self.num_experts_per_tok,
-            renormalize=self.norm_topk_prob,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
-            scoring_func=self.scoring_func,
-        )
-        topk_weights.mul_(self.routed_scaling_factor)
-
-        if self.redundancy_expert_num > 0:
-            redundancy_topk_ids_repair(
-                topk_ids=topk_idx,
-                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
-                ep_expert_num=self.ep_n_routed_experts,
-                global_rank=self.global_rank_,
-                expert_counter=self.routed_expert_counter_tensor,
-                enable_counter=self.auto_update_redundancy_expert,
-            )
-
-        topk_idx = topk_idx.to(torch.long)
-        num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
-        recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
-            hidden_states,
-            topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            self.total_expert_num_contain_redundancy,
-            use_fp8=self.use_fp8_w8a8,
-            async_finish=False,
-            return_recv_hook=True,
-        )
-        return recv_x, masked_m, topk_idx, topk_weights, handle, hook
-
-    def select_experts_and_quant_input(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-    ):
-        topk_weights, topk_idx = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            correction_bias=self.e_score_correction_bias,
-            use_grouped_topk=self.use_grouped_topk,
-            top_k=self.num_experts_per_tok,
-            renormalize=self.norm_topk_prob,
-            topk_group=self.topk_group,
-            num_expert_group=self.n_group,
-            scoring_func=self.scoring_func,
-        )
-        topk_weights.mul_(self.routed_scaling_factor)
-        if self.redundancy_expert_num > 0:
-            redundancy_topk_ids_repair(
-                topk_ids=topk_idx,
-                redundancy_expert_ids=self.redundancy_expert_ids_tensor,
-                ep_expert_num=self.ep_n_routed_experts,
-                global_rank=self.global_rank_,
-                expert_counter=self.routed_expert_counter_tensor,
-                enable_counter=self.auto_update_redundancy_expert,
-            )
-        M, K = hidden_states.shape
-        w1, w1_scale = self.w1
-        block_size_k = 0
-        if w1.ndim == 3:
-            block_size_k = w1.shape[2] // w1_scale.shape[2]
-        assert block_size_k == 128, "block_size_k must be 128"
-        qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype)
-        return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale)
-
-    def dispatch(
-        self,
-        qinput_tensor: Tuple[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        overlap_event: Optional[Any] = None,
-    ):
-        buffer = dist_group_manager.ep_buffer
-        # get_dispatch_layout
-        (
-            num_tokens_per_rank,
-            num_tokens_per_rdma_rank,
-            num_tokens_per_expert,
-            is_token_in_rank,
-            previous_event,
-        ) = buffer.get_dispatch_layout(
-            topk_idx,
-            self.total_expert_num_contain_redundancy,
-            previous_event=overlap_event,
-            async_finish=True,
-            allocate_on_comm_stream=True,
-        )
-        recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
-            qinput_tensor,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            num_tokens_per_rank=num_tokens_per_rank,
-            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-            is_token_in_rank=is_token_in_rank,
-            num_tokens_per_expert=num_tokens_per_expert,
-            previous_event=previous_event,
-            async_finish=True,
-            allocate_on_comm_stream=True,
-            expert_alignment=128,
-        )
-
-        def hook():
-            event.current_stream_wait()
-
-        return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook
-
-    def masked_group_gemm(
-        self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int
-    ):
-        w1, w1_scale = self.w1
-        w2, w2_scale = self.w2
-        return masked_group_gemm(recv_x, masked_m, dtype, w1, w1_scale, w2, w2_scale, expected_m=expected_m)
-
-    def prefilled_group_gemm(
-        self,
-        num_recv_tokens_per_expert_list,
-        recv_x: Tuple[torch.Tensor],
-        recv_topk_idx: torch.Tensor,
-        recv_topk_weights: torch.Tensor,
-        hidden_dtype=torch.bfloat16,
-    ):
-        device = recv_x[0].device
-        w1, w1_scale = self.w1
-        w2, w2_scale = self.w2
-        _, K = recv_x[0].shape
-        _, N, _ = w1.shape
-        # scatter
-        all_tokens = sum(num_recv_tokens_per_expert_list)  # calcu padding all nums.
-        # gather_out shape [recive_num_tokens, hidden]
-        gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype)
-        if all_tokens > 0:
-            input_tensor = [
-                torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype),
-                torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32),
-            ]
-            # when m_indices is filled ok.
-            # m_indices show token use which expert, example, [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..]
-            # the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1]
-            # ...
-            m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
-            # output_index shape [recive_num_tokens, topk_num]
-            # output_index use to show the token index in input_tensor
-            output_index = torch.empty_like(recv_topk_idx)
-
-            num_recv_tokens_per_expert = torch.tensor(
-                num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
-            ).cuda(non_blocking=True)
-
-            expert_start_loc = torch.empty_like(num_recv_tokens_per_expert)
-
-            ep_scatter(
-                recv_x[0],
-                recv_x[1],
-                recv_topk_idx,
-                num_recv_tokens_per_expert,
-                expert_start_loc,
-                input_tensor[0],
-                input_tensor[1],
-                m_indices,
-                output_index,
-            )
-            input_tensor[1] = tma_align_input_scale(input_tensor[1])
-            # groupgemm (contiguous layout)
-            gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)
-
-            _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
-
-            # silu_and_mul_fwd + qaunt
-            # TODO fused kernel
-            silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)
-
-            silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
-            qsilu_out, qsilu_out_scale = per_token_group_quant_fp8(
-                silu_out, self.block_size, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True
-            )
-
-            # groupgemm (contiguous layout)
-            gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
-
-            _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices)
-            # gather and local reduce
-            ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
-        else:
-            ######################################## warning ##################################################
-            # here is used to match autotune feature, make moe model run same triton kernel in different rank.
-            # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-            if Autotuner.is_autotune_warmup():
-                _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
-                _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
-                silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
-                _gemm_out_a, _silu_out = None, None
-
-        return gather_out
-
-    def low_latency_combine(
-        self,
-        gemm_out_b: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        handle: Any,
-    ):
-        combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine(
-            gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True
-        )
-        return combined_x, hook
-
-    def combine(
-        self,
-        gemm_out_b: torch.Tensor,
-        handle: Any,
-        overlap_event: Optional[Any] = None,
-    ):
-        # normal combine
-        combined_x, _, event = dist_group_manager.ep_buffer.combine(
-            gemm_out_b,
-            handle,
-            topk_weights=None,
-            async_finish=True,
-            previous_event=overlap_event,
-            allocate_on_comm_stream=True,
-        )
-
-        def hook():
-            event.current_stream_wait()
-
-        return combined_x, hook
-
-    def _fuse(self):
-        if self.quantized_weight:
-            self._fuse_weight_scale()
-        with self.lock:
-            if (
-                hasattr(self, "experts_up_projs")
-                and None not in self.experts_up_projs
-                and None not in self.experts_gate_projs
-                and None not in self.w2_list
-            ):
-                gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape
-                up_out_dim, up_in_dim = self.experts_up_projs[0].shape
-                assert gate_in_dim == up_in_dim
-                dtype = self.experts_gate_projs[0].dtype
-                total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
-
-                w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu")
-
-                for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num):
-                    w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts]
-                    w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts]
-
-                inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
-                w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
-                if not self.quantized_weight and self.quant_method is not None:
-                    qw1_pack = self.quant_method.quantize(w1)
-                    qw2_pack = self.quant_method.quantize(w2)
-                    self.w1[0] = qw1_pack.weight
-                    self.w1[1] = qw1_pack.weight_scale
-                    self.w2[0] = qw2_pack.weight
-                    self.w2[1] = qw2_pack.weight_scale
-                else:
-                    self.w1[0] = self._cuda(w1)
-                    self.w2[0] = self._cuda(w2)
-                delattr(self, "w2_list")
-                delattr(self, "experts_up_projs")
-                delattr(self, "experts_gate_projs")
-
-    def _fuse_weight_scale(self):
-        with self.lock:
-            if (
-                hasattr(self, "experts_up_proj_scales")
-                and None not in self.experts_up_proj_scales
-                and None not in self.experts_gate_proj_scales
-                and None not in self.w2_scale_list
-            ):
-                gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape
-                up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape
-                assert gate_in_dim == up_in_dim
-                dtype = self.experts_gate_proj_scales[0].dtype
-                total_expert_num = self.ep_n_routed_experts + self.redundancy_expert_num
-
-                w1_scale = torch.empty(
-                    (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu"
-                )
-
-                for i_experts in range(self.ep_n_routed_experts + self.redundancy_expert_num):
-                    w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts]
-                    w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts]
-
-                inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
-                w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
-                    len(self.w2_scale_list), inter_shape, hidden_size
-                )
-                self.w1[1] = self._cuda(w1_scale)
-                self.w2[1] = self._cuda(w2_scale)
-                delattr(self, "w2_scale_list")
-                delattr(self, "experts_up_proj_scales")
-                delattr(self, "experts_gate_proj_scales")
-
-    def load_hf_weights(self, weights):
-        n_expert_ep = self.ep_n_routed_experts
-
-        # Load bias
-        if self.e_score_correction_bias_name in weights:
-            self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])
-
-        # Get weight shapes from first expert to determine intermediate size
-        first_expert_idx = 0 + n_expert_ep * self.global_rank_
-        w1_weight_name = f"{self.weight_prefix}.{first_expert_idx}.{self.w1_weight_name}.weight"
-        if w1_weight_name in weights:
-            intermediate_size = weights[w1_weight_name].shape[0]
-            self.intermediate_size = intermediate_size
-
-            # Re-create weights with correct size if needed
-            if self.w1[0].shape[1] != intermediate_size * 2:
-                self._create_weight()
-
-        # Load regular experts
-        for i_experts_ep in range(n_expert_ep):
-            i_experts = i_experts_ep + n_expert_ep * self.global_rank_
-            self._copy_expert_weights(i_experts_ep, i_experts, weights)
-
-        # Load redundant experts
-        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
-            self._copy_expert_weights(n_expert_ep + i, redundant_expert_id, weights)
-
-        if self.quantized_weight:
-            self._load_weight_scale_direct(weights)
-
-    def _copy_expert_weights(self, target_idx, expert_id, weights):
-        """Copy a single expert's weights to pre-allocated GPU memory"""
-        w1_weight = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.weight"
-        w2_weight = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.weight"
-        w3_weight = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.weight"
-
-        intermediate_size = self.intermediate_size
-
-        if w1_weight in weights and w3_weight in weights:
-            # Combine gate and up projections into w1
-            gate_weight = weights[w1_weight]  # [intermediate_size, hidden_size]
-            up_weight = weights[w3_weight]  # [intermediate_size, hidden_size]
-
-            # Copy to pre-allocated memory
-            if not self.quantized_weight and self.quant_method is not None:
-                # Quantized path
-                combined_cpu = torch.empty((intermediate_size * 2, self.hidden_size), dtype=gate_weight.dtype)
-                combined_cpu[:intermediate_size, :] = gate_weight
-                combined_cpu[intermediate_size:, :] = up_weight
-                quantized_pack = self.quant_method.quantize(combined_cpu)
-                self.w1[0][target_idx].copy_(quantized_pack.weight.view(intermediate_size * 2, self.hidden_size))
-                if quantized_pack.weight_scale is not None:
-                    self.w1[1][target_idx].copy_(
-                        quantized_pack.weight_scale.view(intermediate_size * 2, self.hidden_size)
-                    )
-            else:
-                # Regular path
-                self.w1[0][target_idx, :intermediate_size, :].copy_(gate_weight)
-                self.w1[0][target_idx, intermediate_size:, :].copy_(up_weight)
-
-        if w2_weight in weights:
-            # Copy w2 (down projection)
-            w2_weight_tensor = weights[w2_weight]  # [hidden_size, intermediate_size] - already the correct shape
-            if not self.quantized_weight and self.quant_method is not None:
-                quantized_pack = self.quant_method.quantize(w2_weight_tensor)
-                self.w2[0][target_idx].copy_(quantized_pack.weight)
-                if quantized_pack.weight_scale is not None:
-                    self.w2[1][target_idx].copy_(quantized_pack.weight_scale)
-            else:
-                self.w2[0][target_idx].copy_(w2_weight_tensor)
-
-    def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
-        n_expert_ep = self.ep_n_routed_experts
-        for i_experts_ep in range(n_expert_ep):
-            i_experts = i_experts_ep + n_expert_ep * self.global_rank_
-            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
-            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
-            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
-            if w1_scale in weights:
-                self.experts_gate_proj_scales[i_experts_ep] = weights[w1_scale]
-            if w3_scale in weights:
-                self.experts_up_proj_scales[i_experts_ep] = weights[w3_scale]
-
-            if w2_scale in weights:
-                self.w2_scale_list[i_experts_ep] = weights[w2_scale]
-
-        # Load scale parameters for redundant experts
-        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
-            i_experts = redundant_expert_id
-            w1_scale = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.{self.weight_scale_suffix}"
-            w2_scale = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.{self.weight_scale_suffix}"
-            w3_scale = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.{self.weight_scale_suffix}"
-            if w1_scale in weights:
-                self.experts_gate_proj_scales[n_expert_ep + i] = weights[w1_scale]
-            if w3_scale in weights:
-                self.experts_up_proj_scales[n_expert_ep + i] = weights[w3_scale]
-            if w2_scale in weights:
-                self.w2_scale_list[n_expert_ep + i] = weights[w2_scale]
-
-    def _load_weight_scale_direct(self, weights: Dict[str, torch.Tensor]) -> None:
-        """Load weight scales directly to pre-allocated GPU memory"""
-        n_expert_ep = self.ep_n_routed_experts
-
-        # Load regular expert scales
-        for i_experts_ep in range(n_expert_ep):
-            i_experts = i_experts_ep + n_expert_ep * self.global_rank_
-            self._copy_expert_scales(i_experts_ep, i_experts, weights)
-
-        # Load redundant expert scales
-        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
-            self._copy_expert_scales(n_expert_ep + i, redundant_expert_id, weights)
-
-    def _copy_expert_scales(self, target_idx, expert_id, weights):
-        """Copy a single expert's weight scales to pre-allocated GPU memory"""
-        w1_scale = f"{self.weight_prefix}.{expert_id}.{self.w1_weight_name}.{self.weight_scale_suffix}"
-        w2_scale = f"{self.weight_prefix}.{expert_id}.{self.w2_weight_name}.{self.weight_scale_suffix}"
-        w3_scale = f"{self.weight_prefix}.{expert_id}.{self.w3_weight_name}.{self.weight_scale_suffix}"
-
-        intermediate_size = self.intermediate_size
-
-        if w1_scale in weights and w3_scale in weights:
-            # Combine gate and up projection scales into w1 scale
-            gate_scale = weights[w1_scale]  # [intermediate_size, hidden_size]
-            up_scale = weights[w3_scale]  # [intermediate_size, hidden_size]
-
-            # Copy to pre-allocated memory
-            self.w1[1][target_idx, :intermediate_size, :].copy_(gate_scale)
-            self.w1[1][target_idx, intermediate_size:, :].copy_(up_scale)
-
-        if w2_scale in weights:
-            # Copy w2 scale (down projection)
-            w2_scale_tensor = weights[w2_scale]  # [hidden_size, intermediate_size]
-            self.w2[1][target_idx].copy_(w2_scale_tensor)
-
-    def _cuda(self, cpu_tensor):
-        if self.quantized_weight:
-            return cpu_tensor.contiguous().cuda(self.device_id_)
-        return cpu_tensor.contiguous().to(self.data_type_).cuda(self.device_id_)

From 4b27aacc21148d5388f00e3f3518a7d7768b54eb Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 23 Jan 2026 10:24:38 +0000
Subject: [PATCH 42/43] add redundancy assert

---
 .../layer_weights/meta_weights/fused_moe/fused_moe_weight.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
index ced1e9267b..6a1bd0ca45 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py
@@ -86,6 +86,8 @@ def _init_redundancy_expert_params(self):
         self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert
         self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda")
         self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda")
+        # TODO: find out the reason of failure of deepep when redundancy_expert_num is 1.
+        assert self.redundancy_expert_num != 1, "redundancy_expert_num can not be 1 for some unknown hang of deepep."
 
     def _init_parallel_params(self):
         self.local_n_routed_experts = self.n_routed_experts + self.num_fused_shared_experts

From 351387dd67f123bcee104f6e9796a12787044f50 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 23 Jan 2026 10:52:03 +0000
Subject: [PATCH 43/43] fix mm weight with bias

---
 .../basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
index 1133e4d6ab..56aa322b49 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py
@@ -116,7 +116,7 @@ def load_hf_weights(self, weights):
     def _create_weight(self):
         self.bias = None
         if self.bias_names is not None:
-            self.bias = torch.empty(self.cusum_out_dims[-1], dtype=self.data_type_).cuda(get_current_device_id())
+            self.bias = torch.empty(sum(self.out_dims), dtype=self.data_type_).cuda(get_current_device_id())
             self.bias._load_ok = [False] * len(self.bias_names)
         self.mm_param: WeightPack = self.quant_method.create_weight(
             in_dim=self.in_dim, out_dim=sum(self.out_dims), dtype=self.data_type_, device_id=get_current_device_id(