diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
index 3543786f69..ff8d899e8e 100644
--- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -45,8 +45,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
-        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
-
+        cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+        cpu_embed_cache_tensor = (
+            torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+            if cpu_embed_cache_client is None
+            else cpu_embed_cache_client.cpu_embed_cache_tensor
+        )
         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "
             f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
         )
diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
index c24166e13d..6be827ac0a 100644
--- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -34,7 +34,12 @@ def context_forward(
 
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
-        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+        cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+        cpu_embed_cache_tensor = (
+            torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+            if cpu_embed_cache_client is None
+            else cpu_embed_cache_client.cpu_embed_cache_tensor
+        )
         infer_state.cpu_embed_cache_tensor = cpu_embed_cache_tensor
 
         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
index 939843a3eb..9b9fe2569c 100644
--- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
+++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -48,7 +48,12 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
 
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
-        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+        cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+        cpu_embed_cache_tensor = (
+            torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+            if cpu_embed_cache_client is None
+            else cpu_embed_cache_client.cpu_embed_cache_tensor
+        )
         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "