-
Notifications
You must be signed in to change notification settings - Fork 417
[Optimization] Incremental checkpoint save for dcp on torch 2.7.x (ARM CPU optimization) #1525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,10 @@ | ||
# Package init for the DCP (torch.distributed.checkpoint) patch collection.
# Re-exports the patch entry points so callers can enable the incremental /
# cached checkpoint-save optimizations with a single import.
from . import torch_shape_env_simplify_pt28
from .torch_dcp_planner import patch_dcp_save_state_dict, patch_dcp_save_with_cache_storage, patch_default_save_plan

# Public API of this patch package.
__all__ = [
    "patch_default_save_plan",
    "torch_shape_env_simplify_pt28",
    "patch_dcp_save_state_dict",
    "patch_dcp_save_with_cache_storage",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| from typing import Optional | ||
|
|
||
| from torch.distributed.checkpoint import DefaultSavePlanner, Metadata, SavePlan, SavePlanner | ||
| from torch.distributed.checkpoint.planner_helpers import ( # type: ignore[attr-defined] | ||
| _compare_save_plans, | ||
| _merge_delta_local_plans, | ||
| ) | ||
|
|
||
|
|
||
| # copy from torch 2.8.0 planner_helpers.py | ||
| def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool: | ||
| """Check if any delta plan is usable, indicating the plan has changed. | ||
|
|
||
| Args: | ||
| delta_plans (List[SavePlan]): A list of delta plans to check. | ||
| Returns: | ||
| True if any delta plan is usable, False otherwise. | ||
| """ | ||
| return any(delta_plan and delta_plan.usable for delta_plan in delta_plans) # type: ignore[attr-defined] | ||
|
|
||
|
|
||
class XtunerCacheSavePlanner(DefaultSavePlanner):
    """DefaultSavePlanner that additionally caches global-plan ``Metadata``.

    Backports the torch 2.8.0 plan-caching fast path onto torch 2.7.x: when
    the local save plans have not changed between checkpoint steps, the cached
    global plan and metadata are reused, and the per-step "delta" plans are
    empty so the planning collective carries almost no payload.
    On torch >= 2.8.0 (detected via ``SavePlanner._cached_metadata``) this
    class defers entirely to the upstream implementation.
    """

    # Metadata for the global checkpoint plan as computed by `create_global_plan` API.
    # Cached on the coordinator rank.
    # NOTE: class-level cache — shared by all instances, keyed by
    # ``cache_key_prefix + class name``.
    _cached_metadata: dict[str, Metadata] = {}

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
        dedup_save_to_lowest_rank: bool = False,
        enable_plan_caching: bool = False,
        cache_key_prefix: str = "",
    ) -> None:
        """Initialize the planner.

        Args:
            flatten_state_dict: Passed through to ``DefaultSavePlanner``.
            flatten_sharded_tensors: Passed through to ``DefaultSavePlanner``.
            dedup_replicated_tensors: Passed through to ``DefaultSavePlanner``.
            dedup_save_to_lowest_rank: Passed through to ``DefaultSavePlanner``.
            enable_plan_caching: Enables the cached-plan fast path; the
                ``type: ignore`` below suggests older torch signatures may not
                accept this positional argument — confirm against the target
                torch version.
            cache_key_prefix: Distinguishes cache entries when multiple
                planners coexist; two instances with the same prefix share
                (and overwrite) the same cache slot.
        """
        super().__init__(
            flatten_state_dict,
            flatten_sharded_tensors,
            dedup_replicated_tensors,
            dedup_save_to_lowest_rank,
            enable_plan_caching,  # type: ignore[call-arg]
        )
        self._cached_plans_key: str = cache_key_prefix + self.__class__.__name__

    def _create_global_plan_with_caching(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
        """Create the global plan, reusing cached plans/metadata when possible.

        Args:
            all_plans: Local plans gathered from all ranks (coordinator side).

        Returns:
            Tuple of (global_plan_delta, global_plan, metadata). The delta
            plans are empty/unusable when nothing changed since the last save.
        """
        if hasattr(SavePlanner, "_cached_metadata"):
            # adaptor for torch >= 2.8.0: upstream SavePlanner caches metadata
            # itself, so the whole backport below is unnecessary there.
            return super()._create_global_plan_with_caching(all_plans)  # type: ignore[misc]

        # ONLY cache ``_cached_metadata`` in XtunerCacheSavePlanner;
        # all-plans and global-plan caches live on torch's SavePlanner class.
        global_plan_delta: list[SavePlan] = []

        if self._cached_plans_key not in SavePlanner._cached_all_plans:  # type: ignore[attr-defined]
            # Case 1: If the plans are not cached, the cache will be hydrated with the
            # all_plans, global_plans (Deduped), and metadata.

            # Cache the original all_plans
            SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans  # type: ignore[attr-defined]
            global_plan, metadata = self._create_global_plan(all_plans)  # type: ignore[attr-defined]
            # Cache the deduped and validated global_plan
            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan  # type: ignore[attr-defined]
            # Cache the metadata
            XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata
            # If plans are not cached, global_plan delta will be the same as global plan.
            return global_plan, global_plan, metadata

        # Case 2: Plans are cached
        if not _contains_usable_plan(all_plans):
            # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
            # Global plan delta will be empty plans to avoid the collective overhead.
            # We can reuse the deduped global plan and metadata from the cache directly.
            global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)  # type: ignore[call-arg]
            global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]  # type: ignore[attr-defined]
            metadata = XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key]
        else:
            # Case 2.2: Plans are cached but the local plans have changed.
            # We will merge the changed local plans with the cached local plans.
            # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
            # Global plan delta will be created by comparing the new global plan with the cached global plan.
            # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
            merged_plans = _merge_delta_local_plans(SavePlanner._cached_all_plans[self._cached_plans_key], all_plans)  # type: ignore[attr-defined]
            # Cache the updated local plans
            SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans  # type: ignore[attr-defined]
            global_plan, metadata = self._create_global_plan(merged_plans)  # type: ignore[attr-defined]

            # NOTE(review): accesses the cache via ``self.`` here but via
            # ``SavePlanner.`` elsewhere — both resolve to the same class-level
            # dict through attribute lookup, but the style is inconsistent.
            # The key is always present here (Case 1 populated it), so this
            # check is defensive.
            if self._cached_plans_key in self._cached_global_plan:  # type: ignore[attr-defined]
                for cached_plan, new_plan in zip(SavePlanner._cached_global_plan[self._cached_plans_key], global_plan):  # type: ignore[attr-defined]
                    if _compare_save_plans(cached_plan, new_plan):
                        # Unchanged rank: send an empty, unusable placeholder.
                        global_plan_delta.append(SavePlan([], usable=False))  # type: ignore[call-arg]
                    else:
                        global_plan_delta.append(new_plan)

            # Cache the new global plan and the metadata
            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan  # type: ignore[attr-defined]
            XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata

        return global_plan_delta, global_plan, metadata
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| import os | ||
| from collections.abc import Sequence | ||
| from typing import Any, Optional, Union | ||
|
|
||
| import torch | ||
| from packaging import version | ||
| from torch.distributed.checkpoint import FileSystemWriter, Metadata, SavePlan, SavePlanner | ||
| from torch.distributed.checkpoint._extension import ( | ||
| StreamTransformExtension, | ||
| ) | ||
| from torch.distributed.checkpoint.storage import ( | ||
| WriteResult, | ||
| ) | ||
| from torch.futures import Future | ||
|
|
||
|
|
||
# PyTorch 2.7+ introduced the ``_extensions`` parameter for FileSystemWriter;
# gate its use at import time so this module also runs on torch 2.6.
_TORCH_DCP_FSWRITER_HAS_EXTENSIONS = version.parse(torch.__version__) >= version.parse("2.7.0")
|
|
||
|
|
||
| def _compare_write_results(write_results: list[WriteResult], other_write_results: list[WriteResult]) -> bool: | ||
| """Compare two lists of WriteResults for equality. | ||
|
|
||
| Args: | ||
| write_results: First list of WriteResults to compare. | ||
| other_write_results: Second list of WriteResults to compare. | ||
|
|
||
| Returns: | ||
| True if both lists have the same length and all elements are equal, | ||
| False otherwise. | ||
| """ | ||
|
|
||
| # Both the plans should have the same number of items | ||
| if len(write_results) != len(other_write_results): | ||
| return False | ||
|
|
||
| # Both the plans should have the same write items. | ||
| for write_item, other_write_item in zip(write_results, other_write_results): | ||
| # Write item type should be same | ||
| if write_item != other_write_item: | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| def _contains_new_write_results(results: list[list[WriteResult]]) -> bool: | ||
| return any(delta_result for delta_result in results) | ||
|
|
||
|
|
||
class XtunerCacheWriter(FileSystemWriter):
    """FileSystemWriter with optional caching of per-rank write results.

    When ``enable_write_result_caching`` is on, a rank whose write results are
    identical to the previous save returns an empty result list from
    ``write_data``; the coordinator's ``finish`` then substitutes the cached
    full results. This avoids re-serializing unchanged metadata on repeated
    (incremental) checkpoint saves.
    """

    # Save write results for the current rank as computed by `write_data` API.
    # Cached on the local rank.
    # NOTE: class-level cache — shared by all instances with the same key.
    _cache_write_results: dict[str, list[WriteResult]] = {}

    # Collection of all the write results from all the ranks.
    # This is the ``results`` input to the `finish` API.
    # Cached on the coordinator rank.
    _cached_all_write_results: dict[str, list[list[WriteResult]]] = {}

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        enable_write_result_caching: bool = False,
        cache_key_prefix: str = "",
    ) -> None:
        """Initialize the writer.

        Args:
            path: Checkpoint destination (passed to ``FileSystemWriter``).
            single_file_per_rank: Passed through to ``FileSystemWriter``.
            sync_files: Passed through to ``FileSystemWriter``.
            thread_count: Passed through to ``FileSystemWriter``.
            per_thread_copy_ahead: Passed through to ``FileSystemWriter``.
            cache_staged_state_dict: Passed through to ``FileSystemWriter``.
            overwrite: Passed through to ``FileSystemWriter``.
            _extensions: Stream transforms; only forwarded on torch >= 2.7,
                silently dropped on 2.6 (where the parameter does not exist).
            enable_write_result_caching: Enables the result-caching fast path.
            cache_key_prefix: Distinguishes cache entries; instances sharing a
                prefix share (and overwrite) the same cache slot.
        """
        # Build kwargs conditionally to support both PyTorch 2.6 and 2.7+
        kwargs: dict[str, Any] = dict()
        if _TORCH_DCP_FSWRITER_HAS_EXTENSIONS:
            kwargs["_extensions"] = _extensions
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
            **kwargs,
        )
        self._enable_write_result_caching = enable_write_result_caching
        self._cached_write_results_key = cache_key_prefix + self.__class__.__name__

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[list[WriteResult]]:
        """Write the plan's data, optionally deduplicating via the cache.

        Returns a future resolving to this rank's WriteResults, or to an
        empty list when caching is on and the results are unchanged.
        """
        all_writes_fut = super().write_data(plan, planner)

        if self._enable_write_result_caching:
            all_writes_fut = self._get_write_future_with_caching(all_writes_fut)
        return all_writes_fut

    def _get_write_future_with_caching(self, all_writes_fut):
        """Wrap a write future so unchanged results collapse to an empty list.

        NOTE(review): ``wait()`` blocks here, making the returned future
        effectively synchronous — presumably acceptable for this save path,
        but confirm it does not defeat async checkpointing.
        """
        new_fut: Future[list[WriteResult]] = Future()
        all_writes_fut.wait()

        if self._cached_write_results_key not in XtunerCacheWriter._cache_write_results:
            # Case 1: first save — hydrate the cache and pass results through.
            XtunerCacheWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
            new_fut.set_result(all_writes_fut.value())
        elif _compare_write_results(
            all_writes_fut.value(), XtunerCacheWriter._cache_write_results[self._cached_write_results_key]
        ):
            # Case 2: results equal the cached ones — return an empty list as
            # the "unchanged" sentinel consumed by ``finish``.
            new_fut.set_result([])
        else:
            # Case 3: results changed — refresh the cache and pass through.
            XtunerCacheWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
            new_fut.set_result(all_writes_fut.value())

        return new_fut

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        """Finalize the checkpoint, restoring cached results if unchanged."""
        if self._enable_write_result_caching:
            results = self._get_results_from_caching(results)

        super().finish(metadata, results)

    def _get_results_from_caching(self, results: list[list[WriteResult]]):
        """Resolve the per-rank results against the coordinator-side cache.

        Empty per-rank lists (the "unchanged" sentinel from
        ``_get_write_future_with_caching``) trigger reuse of the cached
        full results.
        """
        if self._cached_write_results_key not in XtunerCacheWriter._cached_all_write_results:
            # Case 1: first save — hydrate the coordinator-side cache.
            XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key] = results
        elif not _contains_new_write_results(results):
            # Case 2: every rank reported "unchanged" — reuse cached results.
            results = XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key]
        else:
            # Case 3: some ranks changed. NOTE(review): overwrites the whole
            # cache with the new results; the original author left
            # "TODO: merge" here — a partial update (mixing cached results for
            # the unchanged ranks) appears intended but is not implemented.
            XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key] = results

        return results
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.