Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

- All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`.
- Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`).
- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`.
- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files <changed_files>`. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`.
- `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead.
- New files require license header:
```
Expand Down
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

- All commits MUST have a `Signed-off-by` line (use `--signoff`). Get the name and email from `git config user.name` / `git config user.email`.
- Formatting: yapf (column_limit=119, `.style.yapf`) + flake8 (`.flake8`).
- Always verify changed files pass pre-commit checks before committing. Config: `.pre-commit-config.yaml`.
- Always verify changed files pass pre-commit checks before committing: `pre-commit run --files <changed_files>`. Only check modified files, not the entire codebase. Config: `.pre-commit-config.yaml`.
- `check-torchdist` hook: NEVER directly import torch's distributed module. Use `import deepspeed.comm as dist` instead.
- New files require license header:
```
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

## Latest News

* [2026/03] [Our SuperOffload work received an Honorable Mention for the ASPLOS 2026 Best Paper Award](https://dl.acm.org/doi/10.1145/3760250.3762217)

* [2025/12] [DeepSpeed Core API updates: PyTorch-style backward and low-precision master states](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/core_api_update/README.md)

* [2025/11] [DeepSpeed ZeRO++ powers large-scale distillation training of LLMs for Recommendation Systems at LinkedIn](https://aclanthology.org/2025.emnlp-industry.119/)
Expand Down
6 changes: 2 additions & 4 deletions deepspeed/ops/adam/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,12 @@ def rollback_subgroup(self, sub_group_id: int, closure=None):
f"CPUAdam param is on {param.device} and must be 'cpu', "
f"make sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config.")

# Decrement step count
subgroup_state['step'] -= 1

# Extract hyperparameters
beta1, beta2 = group['betas']

self.ds_opt_adam.adam_rollback(self.opt_id, subgroup_state['step'], group['lr'], beta1, beta2,
group['eps'], group['weight_decay'], group['bias_correction'],
param.data, param.grad.data, subgroup_state['exp_avg'],
subgroup_state['exp_avg_sq'])

subgroup_state['step'] -= 1
return loss
14 changes: 11 additions & 3 deletions deepspeed/ops/transformer/inference/triton/matmul_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@ def is_nfs_path(path):
# Normalize the path to get the absolute path
path = os.path.abspath(path)

# Walk up to the nearest existing ancestor so 'df' does not fail
# when the target directory has not been created yet (see #7642).
while not os.path.exists(path):
parent = os.path.dirname(path)
if parent == path:
break
path = parent

# Use the 'df' command to find the file system type for the given path
try:
output = subprocess.check_output(['df', '-T', path], encoding='utf-8')
except subprocess.CalledProcessError:
return False # Command failed
output = subprocess.check_output(['df', '-T', path], encoding='utf-8', stderr=subprocess.DEVNULL)
except (subprocess.CalledProcessError, FileNotFoundError):
return False # Command failed or 'df' not available

# Process the output of 'df -T' to check for 'nfs' in the filesystem type column
lines = output.strip().split('\n')
Expand Down
5 changes: 4 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,9 @@ def zero_legacy_stage1(self):
def zero_ignore_unused_parameters(self):
return self._config.zero_config.ignore_unused_parameters

def zero_save_muon_momentum_buffer_in_memory(self):
return self._config.zero_config.save_muon_momentum_buffer_in_memory

def tensor_parallel_config(self):
return self._config.tensor_parallel_config

Expand Down Expand Up @@ -1733,7 +1736,6 @@ def _configure_basic_optimizer(self, model_parameters):
optimizer = MuSGD(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == MUON_OPTIMIZER:
zero_stage = self.zero_optimization_stage()
assert zero_stage <= ZeroStageEnum.gradients, "Muon optimizer is not yet compatible with ZeRO Stage 3"
if not all([hasattr(p, 'use_muon') for p in model_parameters]):
msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
"please set by `param.use_muon = True / False` for all params"
Expand Down Expand Up @@ -2045,6 +2047,7 @@ def _configure_zero_optimizer(self, optimizer):
log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
enable_sanity_checks=self.is_sanity_checks_enabled(),
cpuadam_cores_perc=self.cpuadam_cores_perc(),
save_muon_momentum_buffer_in_memory=self.zero_save_muon_momentum_buffer_in_memory(),
)

else:
Expand Down
Loading