
[Feature] Offload optimizer states to CPU to reduce NPU memory with minimal performance impact#1524

Open
tina-wen wants to merge 1 commit into InternLM:main from tina-wen:swap_optimizer

Conversation


@tina-wen commented Mar 3, 2026

Description

This PR adds CPU offloading for optimizer states to reduce NPU memory usage. Optimizer states stay in host memory and are transferred to the device only during optimizer.step(), via host-to-device (h2d) and device-to-host (d2h) copies.

Changes

  • Offload optimizer states to CPU memory
  • Transfer to device only during optimizer.step()
  • Resolve conflicts with DCP.save and RL offload_optimizer
  • Trade a small amount of optimizer-step latency for significant memory savings
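As a rough illustration of the offload pattern described above (a framework-free sketch with hypothetical names; the actual PR moves torch tensors between host memory and the NPU), optimizer state lives on the host and visits the device only inside step():

```python
# Minimal sketch of CPU-offloaded optimizer state (illustrative only; the
# real PR performs h2d/d2h tensor copies on NPU streams).
class OffloadedSGD:
    def __init__(self, params, lr=0.1, momentum=0.9):
        self.params = params        # stands in for device-resident parameters
        self.lr = lr
        self.momentum = momentum
        # Optimizer state stays in host memory between steps.
        self.host_state = [0.0] * len(params)

    def step(self, grads):
        for i, g in enumerate(grads):
            m = self.host_state[i]          # "h2d": fetch state for the update
            m = self.momentum * m + g       # momentum accumulation
            self.params[i] -= self.lr * m   # apply the update on the "device"
            self.host_state[i] = m          # "d2h": state returns to the host

opt = OffloadedSGD([1.0, 2.0])
opt.step([0.5, -0.5])
```

Between steps, no optimizer state occupies device memory, which is where the reported NPU savings come from; the cost is the transfer time paid inside each step().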

Testing

Verified with:

  • Memory reduction tests
  • DCP checkpoint compatibility
  • RL optimization workflows

Collaborator

@HAOCHENYE left a comment


  1. Please

self.optimizer = optimizer
self.swap_optimizer_times = swap_optimizer_times
if SwapOptimizerOperate.swap_to_device_stream is None:
    SwapOptimizerOperate.swap_to_device_stream = torch.npu.Stream()
Collaborator


Please use get_torch_device_module() to obtain DEVICE_MODULE instead of referencing torch.npu directly.
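To make the suggestion concrete (a minimal sketch; xtuner's actual get_torch_device_module may differ in signature and behavior, and the stub objects below stand in for the real torch backend modules), the idea is to resolve the backend module once and then write DEVICE_MODULE.Stream() rather than torch.npu.Stream():

```python
from types import SimpleNamespace

# Stand-ins for torch.npu / torch.cuda so the dispatch logic is runnable
# here; in real code these are the actual torch backend modules.
torch_stub = SimpleNamespace(
    npu=SimpleNamespace(name="npu", is_available=lambda: True),
    cuda=SimpleNamespace(name="cuda", is_available=lambda: True),
)

def get_torch_device_module(torch_mod):
    """Prefer the NPU backend when present and available, else fall back to CUDA."""
    npu = getattr(torch_mod, "npu", None)
    if npu is not None and npu.is_available():
        return npu
    return torch_mod.cuda

DEVICE_MODULE = get_torch_device_module(torch_stub)
# Call sites then use DEVICE_MODULE.Stream(), DEVICE_MODULE.current_stream(), ...
```

This keeps the swap code portable: the same call sites work whether the job runs on NPU or CUDA hardware.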

self.optimizer.swap_numel = swap_numel

swap_memory = swap_num * 8 / 1024 / 1024
print('[Rank {}] swap optimizer param num: {}, param size: {}MB\n'.format(torch.npu.current_device(), swap_num, swap_memory), end='')
Collaborator


Please use the logger defined in xtuner instead of print.
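A hedged sketch of what that change could look like (the logger name and the helper below are illustrative, not the PR's or xtuner's actual code):

```python
import logging

# Assumed logger name for illustration; real code should use the logger
# instance that xtuner itself exposes.
logger = logging.getLogger("xtuner")

def report_swap(rank, swap_num, bytes_per_param=8):
    """Log the swapped parameter count and size instead of print()-ing it."""
    swap_mb = swap_num * bytes_per_param / 1024 / 1024
    logger.info("[Rank %s] swap optimizer param num: %s, param size: %.2fMB",
                rank, swap_num, swap_mb)
    return swap_mb

report_swap(0, 1024 * 1024)  # 1M params at 8 bytes each -> 8.00MB, at INFO level
```

Routing through the logger keeps the message subject to rank-aware filtering and log-level control instead of unconditionally writing to stdout.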

cls.swap_to_host_events_map[param] = None

@classmethod
def swap_all_to_device(cls):
Collaborator


Should swap_to_device_stream wait for the main device stream, to avoid the memory peak caused by backward computation overlapping with swap_all_to_device?
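The concern can be illustrated with a toy two-stream model (illustrative only; real code would synchronize via stream wait_stream/event APIs on the device runtime):

```python
# Toy model: streams just record ops, and wait_stream records an ordering
# dependency. Without the wait, swap_all_to_device could overlap backward
# and raise peak device memory.
class ToyStream:
    def __init__(self, name, timeline):
        self.name, self.timeline = name, timeline

    def run(self, op):
        self.timeline.append((self.name, op))

    def wait_stream(self, other):
        # A real runtime inserts an event dependency; here we only record intent.
        self.timeline.append((self.name, f"wait({other.name})"))

timeline = []
main = ToyStream("main", timeline)
swap = ToyStream("swap_to_device", timeline)

main.run("backward")           # frees activation memory as it completes
swap.wait_stream(main)         # ensure backward has finished before h2d copies
swap.run("swap_all_to_device")
```

The ordering dependency guarantees the h2d copies begin only after backward's memory pressure has subsided, trading a little overlap for a lower peak.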

Collaborator


Is swap_to_device_stream unused here?

cls.swap_to_device_events_map[param] = torch.npu.current_stream().record_event()

@classmethod
def wait_swap_to_device_event(cls, param):
Collaborator


unused function?

[group['step']], amsgrad=amsgrad, lr=group['lr'], beta1=beta1, beta2=beta2, weight_decay=group['weight_decay'],
eps=group['eps'], maximize=group['maximize'])

# it may be removed
Collaborator


Why is swap_all_to_host not called here?



2 participants