-
Notifications
You must be signed in to change notification settings - Fork 499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds validation loss to LoRA fine tune single device #2238
base: main
Are you sure you want to change the base?
Adds validation loss to LoRA fine tune single device #2238
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2238
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit df8cd1e with merge base 27fd3a1 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Hi @MaxFrax! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
@felipemello1 Finally I have been able to work on this. I'll make my way through the testing plan, but feel free to share any comment you might already have. |
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
hey @MaxFrax , thank you! I am on PTO this week. I will get to it next week if someone doesnt do it before me. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @MaxFrax, thanks for this PR! The validation loop itself looks pretty reasonable, but I think we should figure out the right way to integrate it. E.g. right now it seems like we perform validation after every training epoch inside of the train
method. Personally I would be in favor of splitting out into multiple methods to make things clearer. That will be a bit more work, but I want to make sure we expose this as clearly as possible. What about something like this?
def validate():
# Should be roughly the code you added
def train():
# Keep this mostly as it is, but add something like:
if self.global_step % self.run_val_every_n_steps == 0:
self.validate()
Then we can expose run_val_every_n_steps
via config. A couple other things to think about would be a maximum number of batches in the val loop and early stopping. I don't think we need to worry about the latter for this PR, but should make sure it's something we're able to support later on.
Also cc @joecummings @felipemello1 @calvinpelletier for anything I've missed here.
for idx, batch in enumerate(self._dataloader_val): | ||
utils.batch_to_device(batch, self._device) | ||
|
||
current_loss = self._loss_step(batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we will need to toggle eval <-> train mode for the model, right? (Another reason having a separate method will probably be cleanest)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ebsmothers ! It makes sense. I will look into it. Do you have any pointers on how to do that?
Hi @ebsmothers ! I have updated the pr with the following edits, as per your recommendation:
If there's any other feedback or comment, just let me know! |
Thanks for making the changes! I will take a look at this PR later today. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me, although I'm getting worried about the proliferation of cfg.get
and cfg validation logic in the recipes. There's nothing inherently wrong about the cfg.get
, but it encourages the use of hidden parameters not exposed to the user. I don't have a good long term solution for this, but since we are only modifying one recipe, maybe we could at least update the lora single device configs to expose these fields so users know that it exists and we can check if the cfg field is None directly?
dataset_validation: null
run_val_every_n_steps: null
max_validation_batches: -1
I know this will affect a lot of files, so open to thoughts. We could also do this in a follow-up.
step=(curr_epoch + 1) * idx + idx, | ||
) | ||
|
||
if self.run_validation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need this check since you only call validate()
if self.run_validation
is True
) | ||
|
||
self.run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) | ||
if self.run_validation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could remove this if statement and keep all the logic below under the first if self.run_validation
check
@@ -335,6 +335,29 @@ def setup(self, cfg: DictConfig) -> None: | |||
last_epoch=self.global_step - 1, | |||
) | |||
|
|||
# Setup the validation dataset | |||
self.run_validation = "dataset_validation" in cfg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would prefer the name validation_dataset
In fact, what do you think about organizing the validation arguments in the configs like so:
validation:
dataset:
...
run_every_n_steps: null
max_batches: -1
that way you can just set validation: null
and query that for self.run_validation
Lets do this as a follow up. I can use my script to bulk update. But lets make sure that we all agree on how it should like in the config. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2238 +/- ##
===========================================
- Coverage 65.41% 23.93% -41.49%
===========================================
Files 344 357 +13
Lines 20658 21153 +495
===========================================
- Hits 13514 5062 -8452
- Misses 7144 16091 +8947 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the PR! It looks simple and the functions make sense!
I added a few comments/ideas. Please push back on what you disagree.
IMO, to approve this, we would need two things:
- Testing, like i suggested in one of the comments. Let me know if you are comfortable running it, otherwise we can help you out
- An example of how the config should look like. The UI should play a big factor on this PR.
@@ -652,6 +675,43 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: | |||
|
|||
return loss | |||
|
|||
def validate(self, curr_epoch) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Do we have model.eval() somewhere?
usually we want to set the model to .eval mode, because some layers have different behavior, like dropout.
By doing that, we then require less memory, because we only need the forward pass, which allows us to have a higher batch_size --> faster validation step.
I am not sure about the implications it may have to compile/FSDP. For example, compile will have to create a new graph that doesnt require grad, so compile time will have to increase. If the number of graph breaks increase, we may have to manually change the threshold of maximum number of graph breaks. (there is an example of that in one of our RL recipes)
- not all recipes have self._loss_step. We would have to standardize and make sure that they all do, but this requires a different PR,.
IMO, if you have access to >1 GPU, I would encourage you to implement it in lora_distributed with QLoRA config, add .eval(), run it:
- with eval + compile + opt_in_bwd + activation ckpt + activation offloading
- without eval + compile + opt_in_bwd + activation ckpt + activation offloading
If nothing breaks, I would feel more confident in approving it
Ps: we would also have to add mode.train() in the training loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@felipemello1 Thanks for the detailed breakdown and suggestions. Should we also unload the model being trained before loading the eval one? Having just one in memory would allow for bigger batch sizes.
That said, I’m currently constrained on time and not very familiar with the implementation details for this. If I were to take this on, it would likely take me a significant amount of time to get it done properly.
Would you be able to take the lead on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hey @MaxFrax , completely understandable. Thanks for sharing it.
I dont think that I will have bandwidth soon, but if i do, this PR is a good start.
@Ankur-singh , cc'ing you in case you are looking for more issues to contribute to! :D
Thank you guys!
|
||
# This bit allows to see the loss for each batch. Not sure about step indexing. | ||
log_dict = { | ||
"val_loss": current_loss.item(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should be logging memory/TPS too. If memory is very low, this would show to the user that they can increase bsz. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It makes sense. I will definitely look into it.
By the way, when is the previous training batch deallocated from memory? Do I have to deallocate manually? It would be handy to do so before staring the validation step to have more memory available.
@@ -779,6 +839,12 @@ def train(self) -> None: | |||
) | |||
) | |||
|
|||
if ( | |||
self.run_validation | |||
and self.global_step % self.run_val_every_n_steps == 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think it makes a lot of sense to eval every N steps, but currently a lot of our training logic is based on epochs. I wonder if we should honor this and keep it based on epoch. Maybe users could pass a float, e.g. every 0.5 epochs.
Not 100% sure about this, just brainstorming.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy to change it as you suggest, I like it better too. I just want to point out that the scheduler has the parameter num_warmup_steps
which contradicts your statement:
our training logic is based on epochs
As user, I'd love the num_warmup_steps to be based on epochs as well.
Thanks @felipemello1 ! Some help on the testing side would be much appreciated.
Should I provide a recipe using the validation dataset? Are we talking about the docs? |
Your PR only contains changes to the recipe. I would encourage you to:
|
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
#1042
Changelog
What are the changes made in this PR?
Adds support to a validation dataset and computes the loss on it after each epoch.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
)pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example