Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions xtuner/v1/patch/__init__.py
Comment thread
tina-wen marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from . import torch_shape_env_simplify_pt28
from .torch_dcp_planner import patch_default_save_plan
from .torch_dcp_planner import patch_dcp_save_state_dict, patch_dcp_save_with_cache_storage, patch_default_save_plan


__all__ = ["patch_default_save_plan", "torch_shape_env_simplify_pt28"]
__all__ = [
"patch_default_save_plan",
"torch_shape_env_simplify_pt28",
"patch_dcp_save_state_dict",
"patch_dcp_save_with_cache_storage",
]
123 changes: 123 additions & 0 deletions xtuner/v1/patch/torch_dcp_planner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
import inspect
import warnings
from pathlib import Path
from typing import Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.distributed.checkpoint.default_planner as torch_default_runner
from torch.distributed.checkpoint import state_dict_saver
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter
from torch.distributed.checkpoint.utils import _DistWrapper


def fake_validate_global_plan(*args, **kwargs):
Expand All @@ -7,3 +22,111 @@ def fake_validate_global_plan(*args, **kwargs):

def patch_default_save_plan():
torch_default_runner._validate_global_plan = fake_validate_global_plan


def _xtuner_save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None

global_metadata = None

ckpt_kwargs = {}
if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
ckpt_kwargs["process_group"] = distW.group

@_dcp_method_logger(**ckpt_kwargs)
def local_step():
assert planner is not None
storage_meta = storage_writer.storage_meta()
if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
warnings.warn(
"The function definition for SavePlanner.set_up_planner has been updated"
" to include the storage_meta argument. Please update your implementation"
" to include this parameter."
)
planner.set_up_planner(state_dict, distW.is_coordinator) # type: ignore[call-arg, arg-type]
else:
planner.set_up_planner(
state_dict=state_dict,
storage_meta=storage_meta,
is_coordinator=distW.is_coordinator,
)
storage_writer.set_up_storage_writer(distW.is_coordinator)

local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan

@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
nonlocal global_metadata

assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans

central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)

@_dcp_method_logger(**ckpt_kwargs)
def write_data():
assert planner is not None
final_local_plan = planner.finish_plan(central_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)

all_writes.wait()
return all_writes.value()

@_dcp_method_logger(**ckpt_kwargs)
def finish_checkpoint(all_results):
assert global_metadata is not None
storage_writer.finish(metadata=global_metadata, results=all_results)
# return global_metadata
return Metadata(state_dict_metadata={}) # This is a patch to avoid broadcast overhead.

return distW.all_reduce("write", write_data, finish_checkpoint)


def patch_dcp_save_state_dict():
Comment thread
tina-wen marked this conversation as resolved.
if hasattr(state_dict_saver, "_save_state_dict"):
original = getattr(state_dict_saver, "_save_state_dict")
if callable(original):
state_dict_saver._save_state_dict = _xtuner_save_state_dict


def patch_dcp_save_with_cache_storage():
original_dcp_save = dcp.save

def dcp_save_with_cache_storage(state_dict, **kwargs):
checkpoint_id = kwargs.get("checkpoint_id", None)
assert checkpoint_id is not None, "checkpoint_id is required for caching mechanism."
checkpoint_id = checkpoint_id if isinstance(checkpoint_id, Path) else Path(checkpoint_id)
planner = kwargs.get("planner", None)
storage_writer = kwargs.get("storage_writer", None)

if storage_writer is None and planner is None:
from xtuner.v1.patch.xtuner_cache_planner import XtunerCacheSavePlanner
from xtuner.v1.patch.xtuner_storage import XtunerCacheWriter

planner = XtunerCacheSavePlanner(enable_plan_caching=True, cache_key_prefix=checkpoint_id.stem)
storage_writer = XtunerCacheWriter(
checkpoint_id, enable_write_result_caching=True, cache_key_prefix=checkpoint_id.stem
)
kwargs["planner"] = planner
kwargs["storage_writer"] = storage_writer

return original_dcp_save(state_dict, **kwargs)

dcp.save = dcp_save_with_cache_storage
99 changes: 99 additions & 0 deletions xtuner/v1/patch/xtuner_cache_planner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

from torch.distributed.checkpoint import DefaultSavePlanner, Metadata, SavePlan, SavePlanner
from torch.distributed.checkpoint.planner_helpers import ( # type: ignore[attr-defined]
_compare_save_plans,
_merge_delta_local_plans,
)


# copy from torch 2.8.0 planner_helpers.py
def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
"""Check if any delta plan is usable, indicating the plan has changed.

Args:
delta_plans (List[SavePlan]): A list of delta plans to check.
Returns:
True if any delta plan is usable, False otherwise.
"""
return any(delta_plan and delta_plan.usable for delta_plan in delta_plans) # type: ignore[attr-defined]


class XtunerCacheSavePlanner(DefaultSavePlanner):
# Metadata for the global checkpoint plan as computed by `create_global_plan` API.
# Cached on the coordinator rank.
_cached_metadata: dict[str, Metadata] = {}

def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
dedup_save_to_lowest_rank: bool = False,
enable_plan_caching: bool = False,
cache_key_prefix: str = "",
) -> None:
super().__init__(
flatten_state_dict,
flatten_sharded_tensors,
dedup_replicated_tensors,
dedup_save_to_lowest_rank,
enable_plan_caching, # type: ignore[call-arg]
)
self._cached_plans_key: str = cache_key_prefix + self.__class__.__name__

def _create_global_plan_with_caching(
self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
if hasattr(SavePlanner, "_cached_metadata"):
# adaptor for torch >= 2.8.0
return super()._create_global_plan_with_caching(all_plans) # type: ignore[misc]

# ONLY cache ``_cached_metadata`` in XtunerCacheSavePlanner
global_plan_delta: list[SavePlan] = []

if self._cached_plans_key not in SavePlanner._cached_all_plans: # type: ignore[attr-defined]
# Case 1: If the plans are not cached, the cache will be hydrated with the
# all_plans, global_plans (Deduped), and metadata.

# Cache the original all_plans
SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans # type: ignore[attr-defined]
global_plan, metadata = self._create_global_plan(all_plans) # type: ignore[attr-defined]
# Cache the deduped and validated global_plan
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan # type: ignore[attr-defined]
# Cache the metadata
XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata
# If plans are not cached, global_plan delta will be the same as global plan.
return global_plan, global_plan, metadata

# Case 2: Plans are cached
if not _contains_usable_plan(all_plans):
# Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
# Global plan delta will be empty plans to avoid the collective overhead.
# We can reuse the deduped global plan and metadata from the cache directly.
global_plan_delta = [SavePlan([], usable=False)] * len(all_plans) # type: ignore[call-arg]
global_plan = SavePlanner._cached_global_plan[self._cached_plans_key] # type: ignore[attr-defined]
metadata = XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key]
else:
# Case 2.2: Plans are cached but the local plans have changed.
# We will merge the changed local plans with the cached local plans.
# Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
# Global plan delta will be created by comparing the new global plan with the cached global plan.
# Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
merged_plans = _merge_delta_local_plans(SavePlanner._cached_all_plans[self._cached_plans_key], all_plans) # type: ignore[attr-defined]
# Cache the updated local plans
SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans # type: ignore[attr-defined]
global_plan, metadata = self._create_global_plan(merged_plans) # type: ignore[attr-defined]

if self._cached_plans_key in self._cached_global_plan: # type: ignore[attr-defined]
for cached_plan, new_plan in zip(SavePlanner._cached_global_plan[self._cached_plans_key], global_plan): # type: ignore[attr-defined]
if _compare_save_plans(cached_plan, new_plan):
global_plan_delta.append(SavePlan([], usable=False)) # type: ignore[call-arg]
else:
global_plan_delta.append(new_plan)

# Cache the new global plan and the metadata
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan # type: ignore[attr-defined]
XtunerCacheSavePlanner._cached_metadata[self._cached_plans_key] = metadata

return global_plan_delta, global_plan, metadata
138 changes: 138 additions & 0 deletions xtuner/v1/patch/xtuner_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import os
from collections.abc import Sequence
from typing import Any, Optional, Union

import torch
from packaging import version
from torch.distributed.checkpoint import FileSystemWriter, Metadata, SavePlan, SavePlanner
from torch.distributed.checkpoint._extension import (
StreamTransformExtension,
)
from torch.distributed.checkpoint.storage import (
WriteResult,
)
from torch.futures import Future


# PyTorch 2.7+ introduced _extensions parameter for FileSystemWriter
_TORCH_DCP_FSWRITER_HAS_EXTENSIONS = version.parse(torch.__version__) >= version.parse("2.7.0")


def _compare_write_results(write_results: list[WriteResult], other_write_results: list[WriteResult]) -> bool:
"""Compare two lists of WriteResults for equality.

Args:
write_results: First list of WriteResults to compare.
other_write_results: Second list of WriteResults to compare.

Returns:
True if both lists have the same length and all elements are equal,
False otherwise.
"""

# Both the plans should have the same number of items
if len(write_results) != len(other_write_results):
return False

# Both the plans should have the same write items.
for write_item, other_write_item in zip(write_results, other_write_results):
# Write item type should be same
if write_item != other_write_item:
return False

return True


def _contains_new_write_results(results: list[list[WriteResult]]) -> bool:
return any(delta_result for delta_result in results)


class XtunerCacheWriter(FileSystemWriter):
# Save write results for the current rank as computed by `write_data` API
# Cached on the local rank.
_cache_write_results: dict[str, list[WriteResult]] = {}

# Collection of all the write results from all the ranks.
# This is the ``results`` input to the `finish` API.
# Cached on the coordinator rank.
_cached_all_write_results: dict[str, list[list[WriteResult]]] = {}

def __init__(
self,
path: Union[str, os.PathLike],
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
cache_staged_state_dict: bool = False,
overwrite: bool = True,
_extensions: Optional[Sequence[StreamTransformExtension]] = None,
enable_write_result_caching: bool = False,
cache_key_prefix: str = "",
) -> None:
# Build kwargs conditionally to support both PyTorch 2.6 and 2.7+
kwargs: dict[str, Any] = dict()
if _TORCH_DCP_FSWRITER_HAS_EXTENSIONS:
kwargs["_extensions"] = _extensions
super().__init__(
path,
single_file_per_rank=single_file_per_rank,
sync_files=sync_files,
thread_count=thread_count,
per_thread_copy_ahead=per_thread_copy_ahead,
cache_staged_state_dict=cache_staged_state_dict,
overwrite=overwrite,
**kwargs,
)
self._enable_write_result_caching = enable_write_result_caching
self._cached_write_results_key = cache_key_prefix + self.__class__.__name__

def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[list[WriteResult]]:
all_writes_fut = super().write_data(plan, planner)

if self._enable_write_result_caching:
all_writes_fut = self._get_write_future_with_caching(all_writes_fut)
return all_writes_fut

def _get_write_future_with_caching(self, all_writes_fut):
new_fut: Future[list[WriteResult]] = Future()
all_writes_fut.wait()

if self._cached_write_results_key not in XtunerCacheWriter._cache_write_results:
# Case 1: If the write results are not cached,.............
XtunerCacheWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
new_fut.set_result(all_writes_fut.value())
elif _compare_write_results(
all_writes_fut.value(), XtunerCacheWriter._cache_write_results[self._cached_write_results_key]
):
# Case 2: equal
new_fut.set_result([])
else:
# Case 3: not equal
XtunerCacheWriter._cache_write_results[self._cached_write_results_key] = all_writes_fut.value()
new_fut.set_result(all_writes_fut.value())

return new_fut

def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
if self._enable_write_result_caching:
results = self._get_results_from_caching(results)

super().finish(metadata, results)

def _get_results_from_caching(self, results: list[list[WriteResult]]):
if self._cached_write_results_key not in XtunerCacheWriter._cached_all_write_results:
# Case 1:
XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key] = results
elif not _contains_new_write_results(results):
# Case 2: no new
results = XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key]
else:
# Case 3: not equal TODO: merge
XtunerCacheWriter._cached_all_write_results[self._cached_write_results_key] = results

return results
Loading
Loading