diff --git a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py index b65b18dc8..8630b6abc 100644 --- a/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py +++ b/lightllm/common/basemodel/triton_kernel/linear_att_cpu_cache_copy.py @@ -42,7 +42,7 @@ def _copy_kv_buffer_to_cpu_cache( head_scale_size, BLOCK: tl.constexpr, ): - block_index_start = tl.program_id(0) + split_index_start = tl.program_id(0) grid_num = tl.num_programs(0) # 将 所有stride 切成 tl.int64 cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64) @@ -62,7 +62,7 @@ def _copy_kv_buffer_to_cpu_cache( cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64) cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64) - for block_index in range(block_index_start, page_num, grid_num): + for block_index in range(page_num): cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) run_flag = 1 if cpu_page_index == -1: @@ -76,7 +76,7 @@ def _copy_kv_buffer_to_cpu_cache( head_flag = 0 mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index - for i in range(tl.cdiv(gpu_full_att_tail_dim, BLOCK) * run_flag * head_flag): + for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK) * run_flag * head_flag, grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < gpu_full_att_tail_dim per_token_size = gpu_full_att_tail_dim // big_page_token_num @@ -103,7 +103,7 @@ def _copy_kv_buffer_to_cpu_cache( big_page_idx = tl.load(big_page_buffer_ids + block_index) - for i in range(tl.cdiv(cpu_kv_conv_tail_dim, BLOCK) * run_flag): + for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK) * run_flag, grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < cpu_kv_conv_tail_dim cpu_kv_conv_data = tl.load( @@ -119,7 +119,7 @@ def _copy_kv_buffer_to_cpu_cache( ) tl.store(dest_cpu_cache_conv_ptr, cpu_kv_conv_data, mask=mask) - for i in range(tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK) * run_flag): + for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK) * run_flag, grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < cpu_kv_ssm_tail_dim @@ -149,7 +149,7 @@ def copy_kv_buffer_to_cpu_cache( tp_world_size: int, big_page_token_num: int, linear_config: LinearAttCacheConfig, - grid_num: int = 16, + grid_num: int = 12, ): assert len(page_indexes) == len(page_readies) == len(big_page_buffer_ids) assert len(mem_indexes) % len(page_indexes) == 0 @@ -172,15 +172,25 @@ def copy_kv_buffer_to_cpu_cache( else: cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1) - cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1) - cpu_cache_ssm = cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1) + cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64) + # 保证可以以128bit对齐的方式进行数据的load 和 store。 + assert cpu_cache_full_att.shape[-1] % 2 == 0 + + cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + cpu_cache_ssm = ( + cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + ) gpu_kv_full_att_state = gpu_kv_full_att_state.view( gpu_kv_full_att_state.shape[0], gpu_kv_full_att_state.shape[1], -1 - ).view(dtype=torch.uint8) + ).view(dtype=torch.uint64) + gpu_kv_full_att_state = gpu_kv_full_att_state.permute(1, 0, 2) # [s, layer_num, xxdim] - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint8) - cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint8) + # 保证可以以128bit对齐的方式进行数据的load 和 store。 + assert gpu_kv_full_att_state.shape[-1] % 2 == 0 + + cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64) + cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64) gpu_full_att_tail_dim = gpu_kv_full_att_state.shape[-1] * gpu_kv_full_att_state.shape[-2] * big_page_token_num cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1] @@ -195,6 +205,7 @@ def copy_kv_buffer_to_cpu_cache( assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] + assert gpu_kv_full_att_state.stride(2) == 1 assert ( gpu_full_att_tail_dim % big_page_token_num == 0 and (gpu_full_att_tail_dim // big_page_token_num) % full_att_layer_num == 0 @@ -278,7 +289,7 @@ def _copy_cpu_cache_to_kv_buffer( head_scale_size, BLOCK: tl.constexpr, ): - block_index_start = tl.program_id(0) + split_index_start = tl.program_id(0) grid_num = tl.num_programs(0) # 将 所有stride 切成 tl.int64 cpu_cache_full_att_stride_p = tl.cast(cpu_cache_full_att_stride_p, tl.int64) @@ -298,11 +309,11 @@ def _copy_cpu_cache_to_kv_buffer( cpu_kv_ssm_stride_s = tl.cast(cpu_kv_ssm_stride_s, tl.int64) cpu_kv_ssm_stride_d = tl.cast(cpu_kv_ssm_stride_d, tl.int64) - for block_index in range(block_index_start, page_num, grid_num): + for block_index in range(page_num): cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) mem_start_ptr = mem_indexes_ptr + big_page_token_num * block_index - for i in range(tl.cdiv(gpu_full_att_tail_dim, BLOCK)): + for i in range(split_index_start, tl.cdiv(gpu_full_att_tail_dim, BLOCK), grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < gpu_full_att_tail_dim per_token_size = gpu_full_att_tail_dim // big_page_token_num @@ -318,20 +329,26 @@ def _copy_cpu_cache_to_kv_buffer( + (tp_rank // head_scale_size) * cpu_cache_full_att_stride_h + gpu_start_i ) - cpu_full_att_data = tl.load(src_cpu_cache_full_att_ptr, mask=mask & (mem_index != -1), other=0) + # 标记主要是为了让编译器可以以128bit的方式生成指令进行拉取 + mem_mask = mem_index != -1 + mem_mask = tl.max_constancy(mem_mask, [2]) + dim_index = tl.max_contiguous(dim_index, [2]) + mem_index = tl.max_constancy(mem_index, [2]) + + cpu_full_att_data = tl.load(src_cpu_cache_full_att_ptr, mask=mask & mem_mask, other=0) tl.store( gpu_kv_full_att_state + mem_index * gpu_kv_full_att_stride_s + layer_index * gpu_kv_full_att_stride_l - + dim_index * gpu_kv_full_att_stride_d, + + dim_index, cpu_full_att_data, - mask=mask & (mem_index != -1), + mask=mask & mem_mask, ) big_page_idx = tl.load(big_page_buffer_ids + block_index) - for i in range(tl.cdiv(cpu_kv_conv_tail_dim, BLOCK)): + for i in range(split_index_start, tl.cdiv(cpu_kv_conv_tail_dim, BLOCK), grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < cpu_kv_conv_tail_dim @@ -349,7 +366,7 @@ def _copy_cpu_cache_to_kv_buffer( mask=mask, ) - for i in range(tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK)): + for i in range(split_index_start, tl.cdiv(cpu_kv_ssm_tail_dim, BLOCK), grid_num): gpu_start_i = i * BLOCK + tl.arange(0, BLOCK) mask = gpu_start_i < cpu_kv_ssm_tail_dim @@ -379,8 +396,9 @@ def copy_cpu_cache_to_kv_buffer( tp_world_size: int, big_page_token_num: int, linear_config: LinearAttCacheConfig, - grid_num: int = 16, + grid_num: int = 12, ): + assert len(mem_indexes) % len(page_indexes) == 0 BLOCK = 4096 @@ -400,15 +418,25 @@ def copy_cpu_cache_to_kv_buffer( else: cpu_cache_full_att = cpu_cache_tensor[:, 0:a].view(cpu_page_num, linear_config.full_att_all_num_kv_heads, -1) - cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1) - cpu_cache_ssm = cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1) + cpu_cache_full_att = cpu_cache_full_att.view(dtype=torch.uint64) + # 保证可以以128bit对齐的方式进行数据的load 和 store。 + assert cpu_cache_full_att.shape[-1] % 2 == 0 + + cpu_cache_conv = cpu_cache_tensor[:, a : (a + b)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + cpu_cache_ssm = ( + cpu_cache_tensor[:, (a + b) : (a + b + c)].view(cpu_page_num, tp_world_size, -1).view(dtype=torch.uint64) + ) gpu_full_att_kv_state = gpu_full_att_kv_state.view( gpu_full_att_kv_state.shape[0], gpu_full_att_kv_state.shape[1], -1 - ).view(dtype=torch.uint8) + ).view(dtype=torch.uint64) gpu_full_att_kv_state = gpu_full_att_kv_state.permute(1, 0, 2) # [s, layer_num, xxdim] - cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint8) - cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint8) + + # 保证可以以128bit对齐的方式进行数据的load 和 store。 + assert gpu_full_att_kv_state.shape[-1] % 2 == 0 + + cpu_kv_conv_state = cpu_kv_conv_state.view(cpu_kv_conv_state.shape[0], -1).view(dtype=torch.uint64) + cpu_kv_ssm_state = cpu_kv_ssm_state.view(cpu_kv_ssm_state.shape[0], -1).view(dtype=torch.uint64) gpu_full_att_tail_dim = gpu_full_att_kv_state.shape[-1] * gpu_full_att_kv_state.shape[-2] * big_page_token_num cpu_kv_conv_tail_dim = cpu_kv_conv_state.shape[-1] @@ -418,6 +446,7 @@ def copy_cpu_cache_to_kv_buffer( assert gpu_full_att_tail_dim == cpu_cache_full_att.shape[-1] assert cpu_cache_conv.shape[-1] == cpu_kv_conv_state.shape[-1] assert cpu_cache_ssm.shape[-1] == cpu_kv_ssm_state.shape[-1] + assert gpu_full_att_kv_state.stride(2) == 1 assert (tp_rank // head_scale_size) < linear_config.full_att_all_num_kv_heads diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 5c923b5e6..b88946d5d 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -103,7 +103,8 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): if self.need_sync_compute_stream(): # TODO fa3 现在必须使用同步模式, 未来需要移除 - g_infer_context.get_overlap_stream().synchronize() + torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream()) + # g_infer_context.get_overlap_stream().synchronize() mem_manager = self.backend.model.mem_manager req_manager = self.backend.model.req_manager diff --git a/test/benchmark/service/benchmark_multiturn.py b/test/benchmark/service/benchmark_multiturn.py index c1c87b0f5..7387237f4 100644 --- a/test/benchmark/service/benchmark_multiturn.py +++ b/test/benchmark/service/benchmark_multiturn.py @@ -34,15 +34,16 @@ """ import argparse -import asyncio import json import os import random +import threading import time +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional, Tuple, Union -import aiohttp import numpy as np +import requests from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -102,12 +103,12 @@ def append_turn_input( return new_prompt, new_len -async def stream_one_turn( - session: aiohttp.ClientSession, +def stream_one_turn( url: str, model_name: str, prompt: str, max_new_tokens: int, + request_timeout_s: int, ) -> Optional[Dict]: """Send one streaming completion request, return per-turn stats: { @@ -139,16 +140,24 @@ async def stream_one_turn( completion_tokens = 0 cached_tokens = 0 - try: - async with session.post(url, headers=headers, json=payload) as response: - if response.status != 200: - err = await response.text() - print(f"\n[turn failed] status={response.status} body={err[:200]}") - return None - - async for raw in response.content: + with requests.Session() as req_session: + req_session.trust_env = False + with req_session.post( + url, + headers=headers, + json=payload, + stream=True, + timeout=(10, request_timeout_s), + ) as response: + if response.status_code != 200: + err = response.text + raise RuntimeError(f"stream_one_turn failed: status={response.status_code}, body={err[:200]}") + + for raw in response.iter_lines(): + if not raw: + continue line = raw.strip() - if not line or not line.startswith(b"data:"): + if not line.startswith(b"data:"): continue data_str = line[len(b"data:") :].strip() if data_str == b"[DONE]": @@ -183,12 +192,9 @@ async def stream_one_turn( last_token_time = now if text_piece: generated_text_parts.append(text_piece) - except Exception as e: - print(f"\n[turn exception] {e}") - return None if first_token_time is None: - return None + raise RuntimeError("stream_one_turn failed: no token received from stream") return { "ttft": first_token_time - start_time, @@ -200,10 +206,9 @@ async def stream_one_turn( } -async def run_session( +def run_session( session_id: int, tokenizer, - session: aiohttp.ClientSession, url: str, model_name: str, start_input_len: int, @@ -214,7 +219,9 @@ async def run_session( output_len: int, max_turns: int, base_seed: int, + request_timeout_s: int, progress_state: Dict, + progress_lock: threading.Lock, ) -> List[Dict]: """Run a single multi-turn dialogue session. Returns a list of per-turn stat dicts (same schema as stream_one_turn output).""" @@ -223,34 +230,43 @@ async def run_session( per_turn: List[Dict] = [] turn_idx = 0 - while turn_idx < max_turns and prompt_len < max_input_len: - turn_output_len = rng.randint(min_output_len, output_len) - result = await stream_one_turn(session, url, model_name, prompt, turn_output_len) - if result is None: - break - per_turn.append(result) - progress_state["finished_turns"] += 1 - print( - f"\rconc={progress_state['concurrency']} " - f"finished_turns={progress_state['finished_turns']} " - f"active_sessions={progress_state['active_sessions']}", - end="", - ) - turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) - prompt, prompt_len = append_turn_input( - tokenizer, - prompt, - result["generated_text"], - turn_input_len, - rng, - ) - turn_idx += 1 - - progress_state["active_sessions"] -= 1 + try: + while turn_idx < max_turns and prompt_len < max_input_len: + turn_output_len = rng.randint(min_output_len, output_len) + result = stream_one_turn( + url=url, + model_name=model_name, + prompt=prompt, + max_new_tokens=turn_output_len, + request_timeout_s=request_timeout_s, + ) + if result is None: + break + per_turn.append(result) + with progress_lock: + progress_state["finished_turns"] += 1 + print( + f"\rconc={progress_state['concurrency']} " + f"finished_turns={progress_state['finished_turns']} " + f"active_sessions={progress_state['active_sessions']}", + end="", + ) + turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment) + prompt, prompt_len = append_turn_input( + tokenizer, + prompt, + result["generated_text"], + turn_input_len, + rng, + ) + turn_idx += 1 + finally: + with progress_lock: + progress_state["active_sessions"] -= 1 return per_turn -async def run_concurrency_level( +def run_concurrency_level( concurrency: int, tokenizer, url: str, @@ -266,38 +282,39 @@ async def run_concurrency_level( request_timeout_s: int, ) -> Dict: """Run one concurrency level. Returns the aggregated stats dict.""" - timeout = aiohttp.ClientTimeout(total=request_timeout_s) - connector = aiohttp.TCPConnector(limit=max(concurrency * 2, 32)) progress_state = { "concurrency": concurrency, "finished_turns": 0, "active_sessions": concurrency, } + progress_lock = threading.Lock() wall_start = time.time() - async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: - tasks = [ - asyncio.create_task( - run_session( - sid, - tokenizer, - session, - url, - model_name, - start_input_len, - max_input_len, - min_turn_input_increment, - turn_input_increment, - min_output_len, - output_len, - max_turns, - base_seed, - progress_state, - ) + with ThreadPoolExecutor(max_workers=concurrency) as executor: + futures = [ + executor.submit( + run_session, + sid, + tokenizer, + url, + model_name, + start_input_len, + max_input_len, + min_turn_input_increment, + turn_input_increment, + min_output_len, + output_len, + max_turns, + base_seed, + request_timeout_s, + progress_state, + progress_lock, ) for sid in range(concurrency) ] - session_results = await asyncio.gather(*tasks) + session_results: List[List[Dict]] = [] + for fut in as_completed(futures): + session_results.append(fut.result()) wall_end = time.time() wall_time = max(wall_end - wall_start, 1e-9) print() # newline after progress bar @@ -498,31 +515,24 @@ def main() -> None: print(f"max_turns : {args.max_turns}") all_summaries: List[Dict] = [] - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - for concurrency in concurrency_levels: - summary = loop.run_until_complete( - run_concurrency_level( - concurrency=concurrency, - tokenizer=tokenizer, - url=args.url, - model_name=model_name, - start_input_len=args.start_input_len, - max_input_len=args.max_input_len, - min_turn_input_increment=args.min_turn_input_increment, - turn_input_increment=args.turn_input_increment, - min_output_len=args.min_output_len, - output_len=args.output_len, - max_turns=args.max_turns, - base_seed=args.seed, - request_timeout_s=args.request_timeout_s, - ) - ) - print_summary(summary) - all_summaries.append(summary) - finally: - loop.close() + for concurrency in concurrency_levels: + summary = run_concurrency_level( + concurrency=concurrency, + tokenizer=tokenizer, + url=args.url, + model_name=model_name, + start_input_len=args.start_input_len, + max_input_len=args.max_input_len, + min_turn_input_increment=args.min_turn_input_increment, + turn_input_increment=args.turn_input_increment, + min_output_len=args.min_output_len, + output_len=args.output_len, + max_turns=args.max_turns, + base_seed=args.seed, + request_timeout_s=args.request_timeout_s, + ) + print_summary(summary) + all_summaries.append(summary) dump = { "config": { diff --git a/test/cpu_cache_kernel/test_speed.py b/test/cpu_cache_kernel/test_speed.py new file mode 100644 index 000000000..fca24b1a3 --- /dev/null +++ b/test/cpu_cache_kernel/test_speed.py @@ -0,0 +1,274 @@ +""" +Speed benchmark for copy_cpu_cache_to_kv_buffer in linear_att_cpu_cache_copy.py. + +Test configuration (matching the user's LinearAttCacheConfig): + tp_world_size=8, full_att_all_num_kv_heads=2, full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=1, full_att_head_dim=256, + num_linear_k_heads=2, num_linear_v_heads=8, + head_linear_k_dim=128, head_linear_v_dim=128, + conv_kernel_size=4, linear_layer_num=36, + conv_state_dtype=torch.bfloat16, ssm_state_dtype=torch.float32, + full_attention_interval=4, all_layer_num=48 +""" + +import os +import json +import time +import triton +import torch +from easydict import EasyDict + +# --------------------------------------------------------------------------- +# Step 0 – set up environment args BEFORE any import that calls +# get_env_start_args() / LinearAttCacheConfig.load_from_args(). +# --------------------------------------------------------------------------- +_env_args = { + "cpu_cache_token_page_size": 2048 * 8, # big_page_token_num + "linear_att_hash_page_size": 2048, + "linear_att_page_block_num": 8, # 512 * 1 == 512 + "data_type": "bfloat16", + "linear_att_ssm_data_type": "float32", + "model_dir": "/tmp/fake_model", # dummy – not used when config is built directly + "tp": 8, + "dp": 1, + "running_max_req_size": 2048, + "enable_cpu_cache": True, +} +os.environ["LIGHTLLM_START_ARGS"] = json.dumps(_env_args) + +# --------------------------------------------------------------------------- +# Step 1 – build LinearAttCacheConfig directly (avoids needing a real model dir) +# --------------------------------------------------------------------------- +from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig + +linear_config = LinearAttCacheConfig( + tp_world_size=8, + full_att_all_num_kv_heads=2, + full_att_dtype=torch.bfloat16, + full_att_num_kv_heads=1, + full_att_head_dim=256, + num_linear_k_heads=2, + num_linear_v_heads=8, + head_linear_k_dim=128, + head_linear_v_dim=128, + conv_kernel_size=4, + linear_layer_num=36, + conv_state_dtype=torch.bfloat16, + ssm_state_dtype=torch.float32, + full_attention_interval=4, + all_layer_num=48, +) +print(f"LinearAttCacheConfig:\n{linear_config}\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 2 – derive sizes from the config +# --------------------------------------------------------------------------- +big_page_token_num = _env_args["cpu_cache_token_page_size"] # 512 +full_att_layer_num = linear_config.all_layer_num // linear_config.full_attention_interval # 12 + +full_att_bytes = linear_config.get_cpu_cache_full_att_bytes() # per big page +conv_bytes = linear_config.get_cpu_cache_conv_bytes() +ssm_bytes = linear_config.get_cpu_cache_ssm_bytes() +total_bytes = full_att_bytes + conv_bytes + ssm_bytes +print( + f"Per-page bytes full_att={full_att_bytes:,} conv={conv_bytes:,} " f"ssm={ssm_bytes:,} total={total_bytes:,}", + flush=True, +) +total_bytes = linear_config.get_cpu_cache_big_page_bytes() + +# --------------------------------------------------------------------------- +# Step 3 – allocate tensors +# --------------------------------------------------------------------------- +grid_num = 8 +PAGE_NUM = 1 # number of big pages to copy per call +SEQ_LEN = 2048 * 8 # total sequence length in gpu_full_att_kv_state dim-1 +BIG_PAGE_COUNT = PAGE_NUM # big_page_buffer_ids length == page_indexes length + +# --- GPU tensors --- +mem_indexes = torch.arange(0, big_page_token_num * PAGE_NUM, dtype=torch.int64, device="cpu") +big_page_buffer_ids = torch.arange(0, BIG_PAGE_COUNT, dtype=torch.int64, device="cpu") +page_indexes = torch.arange(0, PAGE_NUM, dtype=torch.int32, device="cpu") + +gpu_full_att_kv_state = torch.empty( + ( + full_att_layer_num, + SEQ_LEN, + 2 * max(1, linear_config.full_att_num_kv_heads // linear_config.tp_world_size), + linear_config.full_att_head_dim, + ), + dtype=linear_config.full_att_dtype, + device="cuda", +) + +# --- CPU tensors --- +buffer_count = triton.cdiv(SEQ_LEN, big_page_token_num) + 2 # matches Qwen3NextMemManager + + +conv_shape = linear_config.get_conv_state_shape() +cpu_kv_conv_state = torch.empty( + (buffer_count, linear_config.linear_layer_num, *conv_shape), + dtype=linear_config.conv_state_dtype, + device="cpu", + pin_memory=True, +) + +ssm_shape = linear_config.get_ssm_state_shape() # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim) +cpu_kv_ssm_state = torch.empty( + (buffer_count, linear_config.linear_layer_num, *ssm_shape), + dtype=linear_config.ssm_state_dtype, + device="cpu", + pin_memory=True, +) + +# conv_shape = linear_config.get_conv_state_shape() +# cpu_kv_conv_state = torch.empty( +# (buffer_count, linear_config.linear_layer_num, *conv_shape), +# dtype=linear_config.conv_state_dtype, device="cuda", +# ) + +# ssm_shape = linear_config.get_ssm_state_shape() # (num_linear_v_heads, head_linear_k_dim, head_linear_v_dim) +# cpu_kv_ssm_state = torch.empty( +# (buffer_count, linear_config.linear_layer_num, *ssm_shape), +# dtype=linear_config.ssm_state_dtype, device="cuda", +# ) + + +# cpu_cache_tensor: [page_num, 1, 1, 1, total_bytes] +cpu_cache_tensor = torch.empty( + (PAGE_NUM, 1, 1, 1, total_bytes), + dtype=torch.uint8, + device="cpu", + pin_memory=True, +) + +# Move GPU tensors to CUDA +mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) +big_page_buffer_ids_cuda = big_page_buffer_ids.cuda(non_blocking=True) +page_indexes_cuda = page_indexes.cuda(non_blocking=True) +gpu_full_att_kv_state = gpu_full_att_kv_state.cuda(non_blocking=True) + +torch.cuda.synchronize() +print("All tensors allocated and moved to GPU.\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 4 – import and warm-up the triton kernel +# --------------------------------------------------------------------------- +from lightllm.common.basemodel.triton_kernel.linear_att_cpu_cache_copy import ( + copy_cpu_cache_to_kv_buffer, +) + +print("Warming up …", flush=True) +copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, +) +torch.cuda.synchronize() +print("Warm-up done.\n", flush=True) + +# --------------------------------------------------------------------------- +# Step 5 – benchmark +# --------------------------------------------------------------------------- +WARMUP_ITERS = 10 +BENCH_ITERS = 100 + +print(f"Benchmarking ({BENCH_ITERS} iterations, {PAGE_NUM} pages / {big_page_token_num} tokens each) …", flush=True) + +# Warm-up +for _ in range(WARMUP_ITERS): + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, + ) +torch.cuda.synchronize() + +# Timed runs +times = [] +for _ in range(BENCH_ITERS): + torch.cuda.synchronize() + t0 = time.perf_counter() + copy_cpu_cache_to_kv_buffer( + mem_indexes=mem_indexes_cuda, + big_page_buffer_ids=big_page_buffer_ids_cuda, + page_indexes=page_indexes_cuda, + gpu_full_att_kv_state=gpu_full_att_kv_state, + cpu_kv_conv_state=cpu_kv_conv_state, + cpu_kv_ssm_state=cpu_kv_ssm_state, + cpu_cache_tensor=cpu_cache_tensor, + tp_rank=0, + tp_world_size=linear_config.tp_world_size, + big_page_token_num=big_page_token_num, + linear_config=linear_config, + grid_num=grid_num, + ) + torch.cuda.synchronize() + t1 = time.perf_counter() + times.append(t1 - t0) + +# --------------------------------------------------------------------------- +# Step 6 – report +# --------------------------------------------------------------------------- +import statistics + +times_ms = [t * 1e3 for t in times] +total_tokens = PAGE_NUM * big_page_token_num + +# Calculate head_scale_size (same logic as in copy_cpu_cache_to_kv_buffer) +if linear_config.full_att_all_num_kv_heads % linear_config.tp_world_size == 0: + head_scale_size = 1 +else: + head_scale_size = linear_config.tp_world_size // linear_config.full_att_all_num_kv_heads + +# Each TP rank copies: +# - full_att_bytes / head_scale_size (full attention is sharded by head_scale_size) +# - conv_bytes / tp_world_size (conv state is sharded by tp_rank) +# - ssm_bytes / tp_world_size (ssm state is sharded by tp_rank) +full_att_bytes = linear_config.get_cpu_cache_full_att_bytes() +conv_bytes = linear_config.get_cpu_cache_conv_bytes() +ssm_bytes = linear_config.get_cpu_cache_ssm_bytes() + +bytes_per_page_per_tp = ( + full_att_bytes // head_scale_size + + conv_bytes // linear_config.tp_world_size + + ssm_bytes // linear_config.tp_world_size +) +total_bytes_copied = PAGE_NUM * bytes_per_page_per_tp + +print() +print("=" * 60) +print(f" copy_cpu_cache_to_kv_buffer speed benchmark") +print("=" * 60) +print(f" Pages / call : {PAGE_NUM}") +print(f" Tokens / page : {big_page_token_num}") +print(f" Total tokens / call : {total_tokens}") +print(f" Bytes / page (total) : {total_bytes:,}") +print(f" Bytes / page (per TP) : {bytes_per_page_per_tp:,}") +print(f" Total bytes / call : {total_bytes_copied:,} ({total_bytes_copied / 1024**3:.3f} GB)") +print(f" Iterations : {BENCH_ITERS}") +print(f" Mean latency : {statistics.mean(times_ms):.3f} ms") +print(f" Median latency : {statistics.median(times_ms):.3f} ms") +print(f" Std latency : {statistics.stdev(times_ms):.3f} ms") +print(f" Min latency : {min(times_ms):.3f} ms") +print(f" Max latency : {max(times_ms):.3f} ms") +print(f" Throughput (tokens/s) : {total_tokens / statistics.mean(times_ms) * 1e3:,.0f}") +print(f" Throughput (GB/s) : {total_bytes_copied / 1024**3 / statistics.mean(times_ms) * 1e3:.3f}") +print("=" * 60)