fix: RuntimeError for UCP large DP #6918

Open: wants to merge 3 commits into master

saforem2
Collaborator

We encountered a strange bug when attempting to convert checkpoints (created with DP=768) to universal format.

An overview of the bug as well as a detailed description of the proposed fix is written up in:

argonne-lcf/Megatron-DeepSpeed: ALCF/notes/universal_checkpoint_bug.md

@loadams requested a review from lekurile on December 30, 2024, 17:54
@saforem2
Collaborator Author

@loadams thanks for the formatting fix!

Also, no rush on this. I spoke briefly with @minjiazhang before the holidays about this issue and mentioned that I would write up a more complete description of what I was seeing.

Maybe some minor thoughts:

I think the change in deepspeed/checkpoint/deepspeed_checkpoint.py,
i.e. passing the strip_tensor_paddings argument through to the self.zero_checkpoint.get_state_for_rank call (shown below):

-    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
+    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index, strip_tensor_paddings: bool = True):
         return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
                                                        tp_index=tp_index,
                                                        dp_index=dp_index,
-                                                       keys_to_ignore=[PARAM_SHAPES])
+                                                       keys_to_ignore=[PARAM_SHAPES],
+                                                       strip_tensor_paddings=strip_tensor_paddings)

✅ is OK since this just passes the argument through.
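
To make the "just passes the argument through" point concrete, here is a minimal sketch using hypothetical stand-in classes (FakeZeroCheckpoint and FakeDeepSpeedCheckpoint are invented for illustration and are not the DeepSpeed classes). It only shows that a keyword argument with a default of True leaves existing callers unchanged while letting new callers opt out:

```python
# Hypothetical stand-ins for illustration only; these are NOT the DeepSpeed classes.
class FakeZeroCheckpoint:
    def get_state_for_rank(self, pp_index, tp_index, dp_index,
                           keys_to_ignore=(), strip_tensor_paddings=True):
        # Record what was requested so the forwarding can be inspected.
        return {
            "rank": (pp_index, tp_index, dp_index),
            "stripped": strip_tensor_paddings,
        }


class FakeDeepSpeedCheckpoint:
    def __init__(self):
        self.zero_checkpoint = FakeZeroCheckpoint()

    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index,
                                  strip_tensor_paddings: bool = True) -> dict:
        # The wrapper adds no logic of its own: it forwards the flag unchanged.
        return self.zero_checkpoint.get_state_for_rank(
            pp_index=pp_index,
            tp_index=tp_index,
            dp_index=dp_index,
            keys_to_ignore=["param_shapes"],  # placeholder for PARAM_SHAPES
            strip_tensor_paddings=strip_tensor_paddings,
        )


ckpt = FakeDeepSpeedCheckpoint()
default_state = ckpt.get_zero_checkpoint_state(0, 0, 0)
unstripped_state = ckpt.get_zero_checkpoint_state(0, 0, 0, strip_tensor_paddings=False)
```

Callers that never mention the new argument get the old behavior, which is why this part of the change looks low-risk.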

However, I'm a bit less sure about this change:

     sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index,
                                                  tp_index=tp_index,
                                                  dp_index=dp_index,
+                                                 strip_tensor_paddings=False)

since I'm not completely clear on how the internals of this
_strip_tensor_paddings() function work.

For our purposes, setting this to False, and thereby skipping the:

if strip_tensor_paddings:
    self._strip_tensor_paddings(sd)

block in the get_state_for_rank call seems to work, though I'm not really sure why.
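
As a rough illustration of what the flag gates, the toy below assumes that stripping means trimming trailing padding elements from a flattened buffer. That is an assumption made for the sketch, not a description of what _strip_tensor_paddings() does internally, and the dictionary layout ("flat", "padding") is hypothetical:

```python
import copy


def strip_padding(state: dict) -> None:
    """Toy stand-in: drop `padding` trailing elements from the flat buffer, in place."""
    padding = state["padding"]
    if padding > 0:
        state["flat"] = state["flat"][:-padding]


def get_state_for_rank(saved: dict, strip_tensor_paddings: bool = True) -> dict:
    # Work on a copy so the saved state is never modified.
    sd = copy.deepcopy(saved)
    if strip_tensor_paddings:
        strip_padding(sd)
    return sd


# Three real values followed by two padding zeros (hypothetical layout).
saved = {"flat": [1.0, 2.0, 3.0, 0.0, 0.0], "padding": 2}

stripped = get_state_for_rank(saved)
raw = get_state_for_rank(saved, strip_tensor_paddings=False)
```

With the flag set to False the caller receives the padded buffer as saved, so whatever consumes it downstream has to account for the padding itself.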
