Merged
8 changes: 6 additions & 2 deletions lightllm/models/gemma3/layer_infer/pre_layer_infer.py
@@ -45,8 +45,12 @@ def context_forward(self, input_ids, infer_state, layer_weight):

     from lightllm.server.router.model_infer.infer_batch import g_infer_context

-    cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+    cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+    cpu_embed_cache_tensor = (
+        torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+        if cpu_embed_cache_client is None
+        else cpu_embed_cache_client.cpu_embed_cache_tensor
+    )
Comment on lines +48 to +53

Contributor

medium
While this logic correctly handles the case where cpu_embed_cache_client is None, it is duplicated in qwen3_vl and qwen_vl pre-layer inference files. To improve maintainability and avoid repeating code, consider extracting this logic into a shared helper method in the base class LlamaMultimodalPreLayerInfer.

     assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
         f"Dimension mismatch: text weight dimension is {hidden_size}, "
         f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
7 changes: 6 additions & 1 deletion lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -34,7 +34,12 @@ def context_forward(

     from lightllm.server.router.model_infer.infer_batch import g_infer_context

-    cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+    cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+    cpu_embed_cache_tensor = (
+        torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+        if cpu_embed_cache_client is None
+        else cpu_embed_cache_client.cpu_embed_cache_tensor
+    )
Comment on lines +37 to +42

Contributor

medium
This logic to safely retrieve cpu_embed_cache_tensor is also present in the gemma3 and qwen_vl model files. To adhere to the DRY (Don't Repeat Yourself) principle, this logic should be centralized. A helper method in the LlamaMultimodalPreLayerInfer base class would be an ideal place for it.

     infer_state.cpu_embed_cache_tensor = cpu_embed_cache_tensor

     assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
7 changes: 6 additions & 1 deletion lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -48,7 +48,12 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei

     from lightllm.server.router.model_infer.infer_batch import g_infer_context

-    cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+    cpu_embed_cache_client = g_infer_context.cpu_embed_cache_client
+    cpu_embed_cache_tensor = (
+        torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
+        if cpu_embed_cache_client is None
+        else cpu_embed_cache_client.cpu_embed_cache_tensor
+    )
Comment on lines +51 to +56

Contributor

medium
This logic for safely initializing cpu_embed_cache_tensor is duplicated in the subclasses Gemma3PreLayerInfer and Qwen3VLMultimodalPreLayerInfer. Since this is the base class, you could define a protected helper method here (e.g., _get_cpu_embed_cache_tensor) to encapsulate this logic. The subclasses can then call this method, which would eliminate the code duplication and make future changes easier.
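A minimal sketch of what that helper might look like, assuming `g_infer_context` and the cache client expose the attributes shown in the diff; the stand-in class and the `_get_cpu_embed_cache_tensor` name are illustrative, not the repository's actual API:

```python
import torch


class LlamaMultimodalPreLayerInfer:
    """Illustrative stand-in for the real base class; only the proposed
    helper is sketched here."""

    def _get_cpu_embed_cache_tensor(self, hidden_size, dtype, device):
        # Deferred import mirrors the diff; the attributes of
        # g_infer_context are assumed from the changed lines above.
        from lightllm.server.router.model_infer.infer_batch import g_infer_context

        client = g_infer_context.cpu_embed_cache_client
        if client is None:
            # No cache registered: fall back to an empty 3-D placeholder
            # so the downstream shape assertion still sees a valid dim 2.
            return torch.empty((0, 0, hidden_size), dtype=dtype, device=device)
        return client.cpu_embed_cache_tensor
```

Each subclass's context_forward could then replace the duplicated block with a single call, e.g. `cpu_embed_cache_tensor = self._get_cpu_embed_cache_tensor(hidden_size, dtype, device)`.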


     assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
         f"Dimension mismatch: text weight dimension is {hidden_size}, "