Skip to content

Vit sep#1234

Open
blueswhen wants to merge 4 commits intomainfrom
vit_sep
Open

Vit sep#1234
blueswhen wants to merge 4 commits intomainfrom
vit_sep

Conversation

@blueswhen
Copy link
Collaborator

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the multimodal processing architecture by introducing the capability to run the Vision Transformer (ViT) as a remote, decoupled service. This change enables greater flexibility in deploying and scaling multimodal models by separating the visual inference workload. Complementing this, a Redis-backed caching system is integrated to efficiently manage image embeddings across the distributed environment. Additionally, attention mechanisms are optimized through the adoption of a more concise Flash Attention function, and multimodal parameter handling is refined to support the new distributed and caching paradigms.

Highlights

  • Remote Vision Transformer (ViT) Support: Introduced a new 'remote ViT' mode, allowing the visual processing component to run as a separate service, enhancing deployment flexibility and scalability for multimodal models. This includes new command-line arguments, a 'visual' run mode, and a VITConnectionManager for handling remote connections.
  • Redis-backed Embed Cache: Implemented MemoryCacheWithRedis to manage image embeddings efficiently across distributed components. This new cache leverages Redis for reference counting and LRU eviction, crucial for the remote ViT setup, and introduces related configuration arguments.
  • Flash Attention Integration: Replaced verbose torch.ops.sgl_kernel.fwd.default calls with the more streamlined flash_attn_varlen_func in Fa3VitAttBackend and flashattention_nopad.py, simplifying attention computation.
  • Multimodal Parameter Handling Refinements: Enhanced the handling of image patch numbers and UUIDs for multimodal inputs in lightllm/models/internvl/model.py and lightllm/server/multimodal_params.py, supporting the new caching and remote processing logic.
  • API and Server Enhancements: Added new API endpoints, specifically /get_image_embedding, and extended server run modes to include visual and visual_only to fully support the new remote ViT functionality.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • lightllm/common/basemodel/attention_vit/fa3/fp.py
    • Imported flash_attn_varlen_func.
    • Replaced torch.ops.sgl_kernel.fwd.default with flash_attn_varlen_func for attention computation.
  • lightllm/models/internvl/model.py
    • Added img.patch_num assignment in init_imageitem_extral_params.
    • Refactored get_image_token_length to use a new get_image_patch method.
  • lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
    • Added imports for rpyc, socket, get_shm_name_embed, load_tensor_afs, and get_env_start_args.
    • Initialized self.args and self.cache_client in init for remote ViT.
    • Added _copy_loaded_embed_to_cache method.
    • Modified context_forward to include unique_uids and handle remote ViT embedding loading and release.
  • lightllm/models/vit/triton_kernel/flashattention_nopad.py
    • Replaced torch.ops.sgl_kernel.fwd.default with flash_attn_varlen_func for attention computation.
  • lightllm/server/api_cli.py
    • Added visual and visual_only to run_mode choices.
    • Added new arguments: --image_embed_dir, --enable_remote_vit, --remote_vit_port, --redis_port, --redis_evict_fraction, --start_redis.
  • lightllm/server/api_http.py
    • Imported lightllm_get_image_embedding.
    • Added visual to run_mode check in healthcheck.
    • Added new /get_image_embedding POST endpoint.
    • Added a check for g_objs.httpserver_manager is None in startup_event.
  • lightllm/server/api_lightllm.py
    • Imported JSONResponse.
    • Added lightllm_get_image_embedding asynchronous function to handle image embedding requests.
  • lightllm/server/api_server.py
    • Imported visual_start.
    • Added visual_start(args) call for visual run mode.
  • lightllm/server/api_start.py
    • Imported is_multimodal_mode and start_redis_service.
    • Added _ensure_remote_vit_embed_dir and _prepare_remote_vit_embed_dir functions.
    • Modified setup_signal_handlers to conditionally log http_server_process.pid.
    • Modified normal_or_p_d_start to accept only_prepare argument and include visual and visual_only in run mode checks.
    • Added args.enable_multimodal and _prepare_remote_vit_embed_dir(args).
    • Adjusted port allocation logic for visual_nccl_ports.
    • Reordered and conditionalized start_audio_process and start_visual_process based on disable_vision, disable_audio, and enable_remote_vit.
    • Removed --preload argument from gunicorn commands for pd_master_start and config_server_start.
    • Added visual_start function to handle visual-only server startup, including port allocation, process management for cache and visual processes, and signal handling.
    • Added start_redis_service(args) call in config_server_start.
  • lightllm/server/config_server/api_http.py
    • Imported VIT_Obj.
    • Added registered_visual_server_objs and registered_visual_server_obj_lock.
    • Added /visual_register websocket endpoint for visual server registration.
    • Added /registered_visual_objects GET endpoint to retrieve registered visual objects.
  • lightllm/server/core/objs/io_objs/group_req.py
    • Modified shm_req_indexes to be None if shm_req_objs is None.
  • lightllm/server/embed_cache/impl/memory_cache_with_redis.py
    • Added new file MemoryCacheWithRedis, implementing Redis-backed reference counting and LRU eviction for image embeddings.
  • lightllm/server/embed_cache/impl/naive_memory_cache.py
    • Added _records attribute pointing to _id_to_records.
    • Refactored _add_ref and _del_ref to use a new _update_record_ref method.
    • Added _update_record_ref_by_id method.
    • Changed uid_int assignment in alloc from uuid.uuid1().int to md5sum.
    • Added _del_ref call for add_ref_m_list in alloc failure path.
    • Modified release to use _update_record_ref_by_id.
    • Modified get_items_embed to accept embeding_only argument.
  • lightllm/server/embed_cache/manager.py
    • Imported MemoryCacheWithRedis.
    • Modified exposed_get_items_embed to accept embeding_only argument.
    • Added get_cache_manager function to return either MemoryCacheWithRedis or InMemoryCache based on enable_remote_vit or run_mode.
    • Updated start_cache_manager to use get_cache_manager.
  • lightllm/server/embed_cache/utils.py
    • Added imports for os, time, torch, redis, numpy, typing, io, pathlib.
    • Added _get_afs_path, tensor2bytes, bytes2tensor, save_tensor_afs, load_tensor_afs for AFS (shared file system) operations.
    • Added create_afs, read_afs, free_afs for AFS-based shared memory.
    • Added get_shm_name_embed.
    • Added EmbedRefCountRedis class, implementing Redis-backed reference counting with LRU eviction for MD5 hashes, including Lua scripts for atomic operations and AFS file deletion on eviction.
  • lightllm/server/httpserver/manager.py
    • Initialized VITConnectionManager if enable_multimodal is true, replacing direct ZMQ sockets for visual/audio.
    • Modified _alloc_resource to use uuids instead of md5sums.
    • Added logic in _alloc_resource to handle remote ViT, avoiding local data caching and calling get_items_embed to prevent LRU eviction.
    • Modified _alloc_multimodal_resources to generate uuid from md5sum and patch_num (for images) or extra_params (for audio).
    • Added get_image_embeding async method to handle image embedding requests.
    • Modified transfer_to_next_module to use vit_manager.send_to_vit for multimodal requests and conditionally send to router based on enable_multimodal or enable_remote_vit.
    • Added vit_manager.vit_handle_loop() to handle_loop if enable_multimodal.
  • lightllm/server/multimodal_params.py
    • Added afs_embed attribute to AudioItem.
    • Modified AudioItem.read to return _preload_data directly without freeing.
    • Added patch_num attribute to ImageItem.
    • Modified ImageItem.read to return _preload_data directly.
    • Added ImageItem.free method to clear _preload_data and _data.
    • Added MultimodalParams.free to free resources for all images and audios.
    • Added MultimodalParams.get_all_uuids to retrieve all UUIDs.
  • lightllm/server/router/model_infer/infer_batch.py
    • Modified init_cpu_embed_cache_client to accept init_shm_data argument.
  • lightllm/server/router/model_infer/mode_backend/base_backend.py
    • Modified g_infer_context.init_cpu_embed_cache_client call to pass init_shm_data=self.args.enable_remote_vit.
  • lightllm/server/visualserver/manager.py
    • Imported create_shm, get_shm_name_data.
    • Added args, visual_only, remote_vit attributes to init.
    • Conditionalized ZMQ socket binding for zmq_recv_socket based on remote_vit and visual_only.
    • Modified flush_ready to use get_items_embed(..., True) for ready_image check.
    • Added _recv_reqs method to handle incoming requests, including logic for remote ViT to allocate cache and create SHM data if needed.
    • Modified loop_for_netio_req to use _recv_reqs.
    • Adjusted visual_recv_max_count update logic.
    • Added loop_for_fwd_visual_only for visual-only inference.
    • Added create_forward_loop to conditionally start loop_for_fwd_visual_only or loop_for_fwd and register_loop.
    • Modified start_visual_process to use create_forward_loop and handle visualserver cleanup more robustly.
  • lightllm/server/visualserver/model_infer/model_rpc.py
    • Updated imports for transformers.configuration_utils and lightllm.server.embed_cache.utils.
    • Added args, image_embed_dir, remote_vit attributes to VisualModelRpcServer in exposed_init_model.
    • Passed remote_vit to model initialization.
    • Conditionalized CpuEmbedCacheClient initialization based on remote_vit.
    • Modified exposed_encode to handle remote ViT: skip if tp_rank_id != 0, use get_items_embed(..., True), save embeddings to AFS using save_tensor_afs if remote_vit, otherwise copy to CPU embed cache.
    • Modified VisualModelRpcClient.init to accept conn and initialize _bg for background serving.
    • Removed redundant return statements in VisualModelRpcClient methods.
    • Modified start_model_process to return VisualModelRpcClient(con, ...) instead of con.root.
  • lightllm/server/visualserver/register_loop.py
    • Added new file register_loop, implementing a loop for remote visual servers to register themselves with the config server via WebSocket, sending heartbeats.
  • lightllm/server/visualserver/vit_connect.py
    • Added new file vit_connect, defining VIT_Obj dataclass and VITConnectionManager class for managing connections to local or remote ViT instances, including registration, load balancing, and embedding readiness checks.
  • lightllm/utils/redis_utils.py
    • Added new file redis_utils, implementing start_redis_service to launch and manage a Redis server process.
  • lightllm/utils/start_utils.py
    • Added is_multimodal_mode function to determine if a model is multimodal based on its configuration.
  • requirements.txt
    • Added xformers==0.0.33.post2.
Activity
  • No specific activity (comments, reviews, progress) was provided in the context for this pull request.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant changes to support a separated Vision Transformer (ViT) service, including remote ViT inference, a Redis-backed embedding cache, and new server run modes. The changes are extensive and touch many parts of the server infrastructure. While the overall architecture seems well-thought-out, I've identified several critical issues, including a bug that will cause a ValueError due to incorrect hash string formatting, insecure use of pickle for network data, and overly permissive file/directory permissions. I've also found some medium-severity issues like typos, redundant code, and leftover debug statements. Please address the critical and high-severity issues before merging.

Comment on lines +176 to +177
md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num)
uuid = int(md5sum, 16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The md5sum string is constructed by concatenating a hex digest with an underscore and a number, which results in a string that is not a valid hexadecimal number (e.g., '..._6'). Calling int(md5sum, 16) on this string will raise a ValueError. You should create a single valid hash from all components. For example, you could hash the concatenation of the data and the patch number.

Suggested change
md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num)
uuid = int(md5sum, 16)
md5sum = hashlib.md5(data + str(img.patch_num).encode()).hexdigest()
uuid = int(md5sum, 16)

Comment on lines +186 to +190
md5sum = "{}_{}".format(
hashlib.md5(data).hexdigest(),
hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(),
)
uuid = int(md5sum, 16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The md5sum string is constructed by concatenating two hex digests with an underscore, which results in a string that is not a valid hexadecimal number. Calling int(md5sum, 16) on this string will raise a ValueError. You should create a single valid hash from all components. For example, you could hash the concatenation of the data and the extra parameters.

                    md5sum = hashlib.md5(data + pickle.dumps(audio.extra_params, protocol=4)).hexdigest()
                    uuid = int(md5sum, 16)

"""
异步获取VIT实例信息
"""
uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The URI for an HTTP GET request should start with http://, not ws://. Using ws:// with an HTTP client will likely fail.

Suggested change
uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"
uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects"

Comment on lines +31 to +32
os.makedirs(image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(image_embed_dir, 0o777)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using 0o777 permissions is highly permissive and can be a security risk, as it gives read, write, and execute permissions to everyone on the system. Consider using more restrictive permissions, such as 0o755 for directories, depending on the access requirements.

Suggested change
os.makedirs(image_embed_dir, mode=0o777, exist_ok=True)
os.chmod(image_embed_dir, 0o777)
os.makedirs(image_embed_dir, mode=0o755, exist_ok=True)
os.chmod(image_embed_dir, 0o755)

await websocket.accept()
client_ip, client_port = websocket.client
logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}")
registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Deserializing data with pickle from a network source is a security risk and can lead to arbitrary code execution. Even for internal services, it's safer to use a more secure serialization format like JSON if possible. If you must use pickle, ensure the communication channel is secure and authenticated.

with open(tmp_path, "wb") as f:
torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4)
os.replace(tmp_path, target_path)
os.chmod(target_path, 0o777)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using 0o777 permissions is highly permissive and can be a security risk, as it gives read, write, and execute permissions to everyone on the system. Consider using more restrictive permissions, such as 0o664 for files, depending on the access requirements.

Suggested change
os.chmod(target_path, 0o777)
os.chmod(target_path, 0o664)

f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, target_path)
os.chmod(target_path, 0o777)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using 0o777 permissions is highly permissive and can be a security risk, as it gives read, write, and execute permissions to everyone on the system. Consider using more restrictive permissions, such as 0o664 for files, depending on the access requirements.

Suggested change
os.chmod(target_path, 0o777)
os.chmod(target_path, 0o664)

Comment on lines +81 to +83
for _, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
release_ids.append(img["uuid"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This loop over infer_state.multimodal_params is redundant. You can collect release_ids in the first loop over the same data structure (lines 59-67) to improve efficiency and code clarity. This avoids iterating over the same data twice. You can add release_ids.append(img["uuid"]) to the first loop when self.args.enable_remote_vit is true.

raise e
return

async def get_image_embeding(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a typo in the function name get_image_embeding. It should be get_image_embedding. Please also update the call site in lightllm/server/api_lightllm.py.

    async def get_image_embedding(

req.multimodal_params.free()

try:
print(instance, flush=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This print statement appears to be for debugging and should be removed from production code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant