Full DPO Distributed #2275

Open · wants to merge 2 commits into main

Conversation

@sam-pi commented Jan 17, 2025

Context

Adapted from the great work in #1966

What is the purpose of this PR? Is it to

  • add a new feature

Please link to any issues this PR addresses: relates to #2082

Changelog

What are the changes made in this PR?

  • Adds full DPO distributed training configs and recipes, adapted from the LoRA DPO training recipe (see the example launch command after this list)
  • Includes integration tests
  • Includes configs for Llama 3.1 8B and 70B models
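
As a quick illustration, the new recipe would be launched through the tune CLI roughly as follows (the 8B config name is assumed here to mirror the 70B config referenced later in this PR; adjust the checkpoint directory and process count to your setup):

    # sketch of a single-node launch for the 8B full DPO config (config name assumed)
    tune run --nproc_per_node 8 full_dpo_distributed \
        --config llama3_1/8B_full_dpo \
        checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>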

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API

Commands and Sample Outputs

Full DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/full_dpo
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-06
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 2000
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8B-dpo_3605
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
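
For context on the loss section of the config above, the following is a minimal sketch of the DPO objective that a loss such as torchtune.rlhf.loss.DPOLoss computes from policy and reference log-probabilities (the beta and label_smoothing values map to the config fields; the real class's exact signature and extra bookkeeping are not reproduced here):

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        ref_chosen_logps: torch.Tensor,
        ref_rejected_logps: torch.Tensor,
        beta: float = 0.05,
        label_smoothing: float = 0.0,
    ) -> torch.Tensor:
        # how much more likely the policy makes each response than the frozen reference
        chosen_logratios = policy_chosen_logps - ref_chosen_logps
        rejected_logratios = policy_rejected_logps - ref_rejected_logps
        logits = chosen_logratios - rejected_logratios
        # standard DPO objective, with optional conservative label smoothing
        losses = (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )
        return losses.mean()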

LoRA DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/lora_dpo
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_rank: 256
  lora_alpha: 256
  lora_dropout: 0.0
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
save_adapter_weights_only: false
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-05
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 100
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8Blora-dpo_3603
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
[Screenshot, 2025-01-16: WandB run comparing the full DPO and LoRA DPO training curves]

@pytorch-bot (bot) commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275


@facebook-github-bot added the "CLA Signed" label on Jan 17, 2025
@sam-pi (Author) commented Jan 17, 2025

@joecummings Please take a look and let me know if you have feedback!

(
    reference_chosen_log_probs,
    reference_rejected_log_probs,
    _,
    _,
) = self.concatenated_forward(self._ref_model, batch)
assert not reference_chosen_log_probs.requires_grad
Collaborator:
I'm not sure if we need these here.

Author:
I updated to remove these assertions!
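
For reference, one way to make such an assertion unnecessary is to run the reference model's forward pass under torch.no_grad(), so its log-probabilities can never require gradients (a sketch only; the recipe's final code may differ):

    with torch.no_grad():
        (
            reference_chosen_log_probs,
            reference_rejected_log_probs,
            _,
            _,
        ) = self.concatenated_forward(self._ref_model, batch)
    # tensors produced under no_grad never require grad, so no assertion is needed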


# synchronize before training begins
torch.distributed.barrier()

Collaborator:
You'll also want to make sure dropout is disabled in both models.

Author: [reply not captured in this export]
Collaborator:
Exactly! Though if you have a neater way of doing it feel free : )

Author:
Thanks, I added in that same approach
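
For illustration, a minimal way to disable dropout in both the policy and reference models could look like the sketch below (whether torchtune ships a dedicated helper for this is left as an assumption; this version uses plain PyTorch):

    from torch import nn

    def disable_dropout(model: nn.Module) -> None:
        # zero out every dropout layer so it is a no-op regardless of train/eval mode
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0

    disable_dropout(self._model)      # policy model
    disable_dropout(self._ref_model)  # reference model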

@SalmanMohammadi (Collaborator) commented Jan 20, 2025

Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon. First, a couple of high level points.

Did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)?

I'm particularly interested in the hardware requirements for the 70B config. We may want to think about offering some additional memory performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.
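
As one possible direction for such improvements, the frozen reference model could be sharded with FSDP2 and offloaded to CPU, since it never needs gradients or optimizer state (a sketch assuming the PyTorch >= 2.4 composable FSDP API and a model exposing a .layers attribute; torchtune's own sharding utilities may wrap this differently):

    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

    # shard each transformer layer of the reference model and keep its parameters on CPU
    for layer in ref_model.layers:
        fully_shard(layer, reshard_after_forward=True, offload_policy=CPUOffloadPolicy())
    fully_shard(ref_model, reshard_after_forward=True, offload_policy=CPUOffloadPolicy())

    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad_(False)  # no gradient or optimizer-state memory for the reference model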

ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
)
self._ref_model = self._setup_model(
Collaborator:
We may want to consider a more bespoke way to setup the reference model here. For example, we don't need to do things like enable_activation_offloading=self._enable_activation_offloading, as this should always be set to False for the reference model, which won't be storing intermediate activations under no_grad.

I've left a more general comment about this - the gist is that we don't necessarily wish to apply the same parallelization and memory optimization strategies to both the reference and policy models.

@sam-pi (Author) commented Jan 21, 2025
Ah yes good point, do you think then I can just update to something like the following, or should there be more generalization?

        self._ref_model = self._setup_reference_model(
            cfg_model=cfg.model,
            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
            model_state_dict=self.load_ref_states(cfg.ref_checkpointer),
        )

Also, I found that for 70B it was necessary to set reshard_after_forward=True or else the entire model seemed to be loaded per node/rank.

Collaborator:
Also, I found that for 70B it was necessary to set reshard_after_forward=True or else the entire model seemed to be loaded per node/rank.

Was this for the reference or the policy model?

Ah yes good point, do you think then I can just update to something like the following, or should there be more generalization?

I think having a separate method is good! I'll try to get round to testing out some different configs this weekend. Would you mind if I pushed changes to your branch?

Author:
That is for the reference model. Ok thanks, and no problem for you to push changes to my branch.

Author:
I added a first cut of this, just re-using the _setup_model method wrapped in _setup_reference_model.
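
For readers following along, a wrapper along those lines might look roughly like the sketch below (the _setup_model signature and argument names are assumed from the snippet above; this is an illustration of the idea, not the code that was merged):

    def _setup_reference_model(self, cfg_model, custom_sharded_layers,
                               fsdp_cpu_offload, reshard_after_forward, model_state_dict):
        # reuse the policy-model setup, but drop options that only matter when gradients flow
        ref_model = self._setup_model(
            cfg_model=cfg_model,
            enable_activation_checkpointing=False,
            enable_activation_offloading=False,  # nothing is stored for backward under no_grad
            custom_sharded_layers=custom_sharded_layers,
            fsdp_cpu_offload=fsdp_cpu_offload,
            reshard_after_forward=reshard_after_forward,
            model_state_dict=model_state_dict,
        )
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)
        return ref_model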

@sam-pi (Author) commented Jan 21, 2025

@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I tried re-uploading the screenshot of my WandB run). They show a comparison of a rank/alpha 256 LoRA DPO run against a full DPO run (only 100 iterations).
For Llama3.1-70B-Instruct, I was able to run using 2 nodes, each with 8x H100 GPUs (I think this is just 2x the hardware requirements for running a single non-quantized 70B).

@RdoubleA mentioned this pull request on Jan 21, 2025
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Collaborator:
Since you mentioned you trained on 2 nodes, it'd be good to add the command you used here.

Separately, I'm going to see if I can find a config that can train on a single node with reasonable speeds.

@sam-pi (Author) commented Jan 23, 2025
I looked into running this on 1 node and I couldn't find a way to get it to fit - if you do please feel free to update. Otherwise, maybe it's not worth including this 70B_full_dpo.yaml in the PR since technically I only got this working with some custom scripts using sbatch and torchrun with --nnodes 2.
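
For anyone trying to reproduce the two-node run, a launch along the lines described might look roughly like this (a sketch only; the exact sbatch setup, rendezvous endpoint, and port depend on the cluster, and the flags assume tune run forwards torchrun arguments):

    # run on each of the two nodes (e.g. via sbatch + srun), pointing both at the same rendezvous endpoint
    tune run \
        --nnodes 2 \
        --nproc_per_node 8 \
        --rdzv_backend c10d \
        --rdzv_endpoint ${HEAD_NODE_IP}:29500 \
        full_dpo_distributed \
        --config llama3_1/70B_full_dpo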
