Full DPO Distributed #2275
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@joecummings Please take a look and let me know if you have feedback!
recipes/full_dpo_distributed.py
Outdated
    _,
    _,
) = self.concatenated_forward(self._ref_model, batch)
assert not reference_chosen_log_probs.requires_grad
I'm not sure if we need these here.
I updated to remove these assertions!
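(For context: the assertions are redundant when the reference forward pass already runs with gradients disabled, since its outputs can then never require grad. A minimal sketch of that reasoning; the function and variable names here are illustrative, not the recipe's actual API:)

```python
import torch

@torch.no_grad()  # the reference model is frozen; no autograd graph is built inside
def reference_forward(ref_model, batch):
    # hypothetical stand-in for a concatenated_forward call on the reference model
    logits = ref_model(batch["tokens"])
    log_probs = torch.log_softmax(logits, dim=-1)
    # Under no_grad, requires_grad is always False on these outputs, so an
    # explicit `assert not log_probs.requires_grad` adds nothing.
    return log_probs
```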
# synchronize before training begins
torch.distributed.barrier()
You'll also want to make sure dropout is disabled in both models.
Ok will do, following this pattern in the PPO recipe? https://github.com/pytorch/torchtune/blob/main/recipes/ppo_full_finetune_single_device.py#L569
Exactly! Though if you have a neater way of doing it feel free : )
Thanks, I added in that same approach
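(For reference, one simple way to disable dropout in both models is to zero out every nn.Dropout module after setup. This is a minimal sketch and not necessarily the exact helper the PPO recipe uses; the recipe attribute names in the comment are assumptions:)

```python
import torch.nn as nn

def disable_dropout(model: nn.Module) -> None:
    """Set p=0.0 on every dropout layer so both models run deterministically."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0

# Applied to both models after setup (illustrative names):
# disable_dropout(self._model)
# disable_dropout(self._ref_model)
```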
Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon.

First, a couple of high-level points. Did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)? I'm particularly interested in the hardware requirements for the 70B config.

We may want to think about offering some additional memory and performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.
recipes/full_dpo_distributed.py
Outdated
    ac_mode=cfg.get("ac_mode", None),
    ac_option=cfg.get("ac_option", None),
)
self._ref_model = self._setup_model(
We may want to consider a more bespoke way to set up the reference model here. For example, we don't need to do things like enable_activation_offloading=self._enable_activation_offloading, as this should always be set to False for the reference model, which won't be storing intermediate activations under no_grad.

I've left a more general comment about this - the gist is that we don't necessarily wish to apply the same parallelization and memory optimization strategies to both the reference and policy models.
Ah yes, good point - do you think I can just update to something like the following, or should there be more generalization?

self._ref_model = self._setup_reference_model(
    cfg_model=cfg.model,
    custom_sharded_layers=cfg.get("custom_sharded_layers", None),
    fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
    reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
    model_state_dict=self.load_ref_states(cfg.ref_checkpointer),
)

Also, I found that for 70B it was necessary to set reshard_after_forward=True, or else the entire model seemed to be loaded per node/rank.
> Also, I found that for 70B it was necessary to set reshard_after_forward=True, or else the entire model seemed to be loaded per node/rank.

Was this for the reference or the policy model?

> Ah yes, good point - do you think I can just update to something like the following, or should there be more generalization?

I think having a separate method is good! I'll try to get round to testing out some different configs this weekend. Would you mind if I pushed changes to your branch?
That is for the reference model. Ok thanks, and no problem for you to push changes to my branch.
I added a first cut of this, just re-using the _setup_model method wrapped in _setup_reference_model.
@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I tried re-uploading the screenshot of my WandB run). I tried to show a comparison of a rank/alpha 256 LoRA DPO run against a full DPO run (only 100 iterations).
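(As a rough illustration of the "wrap _setup_model" approach discussed above - a sketch under assumptions, with parameter names mirroring the snippet earlier in the thread rather than the PR's actual implementation - the reference model can reuse the policy setup path while forcing off the training-only memory options and freezing the weights:)

```python
def _setup_reference_model(self, cfg_model, custom_sharded_layers,
                           fsdp_cpu_offload, reshard_after_forward,
                           model_state_dict):
    # Reuse the policy-model setup, but hard-code the training-only options:
    # no activation checkpointing/offloading, since the reference model only
    # ever runs forward passes under no_grad.
    ref_model = self._setup_model(
        cfg_model=cfg_model,
        enable_activation_checkpointing=False,
        enable_activation_offloading=False,
        custom_sharded_layers=custom_sharded_layers,
        fsdp_cpu_offload=fsdp_cpu_offload,
        reshard_after_forward=reshard_after_forward,
        model_state_dict=model_state_dict,
    )
    # Freeze the reference model so it never accumulates gradients.
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad = False
    return ref_model
```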
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Since you mentioned you trained on 2 nodes, it'd be good to add the command you used here.

Separately, I'm going to see if I can find a config that can train on a single node with reasonable speeds.
I looked into running this on 1 node and I couldn't find a way to get it to fit - if you do, please feel free to update. Otherwise, maybe it's not worth including this 70B_full_dpo.yaml in the PR, since technically I only got it working with some custom scripts using sbatch and torchrun with --nnodes 2.
Context
Adapted from the great work in #1966
What is the purpose of this PR?
Please link to any issues this PR addresses: relates to #2082
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.
Commands and Sample Outputs
Full DPO Config
Lora DPO Config