use nn.Sequential to remove python control flow from autoencoder up/downsampling #33

Open
wants to merge 3 commits into
base: main
from


technillogue
Contributor

@technillogue technillogue commented Oct 10, 2024

the current autoencoder implementation causes graph breaks, likely due to python control flow. the hs list constructed in the encoder is also egregious. performance improvement numbers TBD. this does not make an improvement with normal eager torch. if we decide to compile the autoencoder, it will fix the graph breaks with tensorrt, though torch.compile may already see through this structure.
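
For illustration, here is a minimal sketch of the kind of restructuring described above. It is not the PR's code: `NestedStack`, `flatten`, and the `nn.Linear` layers are hypothetical stand-ins for the real ResnetBlock and downsample modules. The nested version walks its levels with Python loops and an `if` inside `forward`; the flat version unrolls that once at construction time.

```python
import torch
from torch import nn


class NestedStack(nn.Module):
    """Nested layout: forward walks the levels with Python loops and an if."""

    def __init__(self, levels=3, blocks=2, dim=8):
        super().__init__()
        self.levels = levels
        self.down = nn.ModuleList()
        for i in range(levels):
            level = nn.Module()
            level.block = nn.ModuleList([nn.Linear(dim, dim) for _ in range(blocks)])
            if i != levels - 1:
                # stand-in for a strided-conv downsample (assumption)
                level.downsample = nn.Linear(dim, dim)
            self.down.append(level)

    def forward(self, x):
        for i, level in enumerate(self.down):
            for block in level.block:
                x = block(x)
            if i != self.levels - 1:
                x = level.downsample(x)
        return x


def flatten(nested):
    """Unroll the loops once, at construction time, into one nn.Sequential."""
    layers = []
    for i, level in enumerate(nested.down):
        layers.extend(level.block)
        if i != nested.levels - 1:
            layers.append(level.downsample)
    return nn.Sequential(*layers)


torch.manual_seed(0)
nested = NestedStack()
flat = flatten(nested)  # reuses the same module objects, so weights are shared
x = torch.randn(4, 8)

# Same layers applied in the same order, so the outputs should match.
torch.testing.assert_close(flat(x), nested(x))
```

Note that the parameter names change along with the structure: the nested module exposes keys such as `down.0.block.0.weight`, while the flat one exposes `0.weight`, `1.weight`, and so on.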

@technillogue technillogue requested a review from yorickvP October 10, 2024 05:07
@technillogue technillogue force-pushed the syl/unroll-encoder-loops branch from 31687a3 to 2a6dd2e Compare October 10, 2024 05:09
@yorickvP
Contributor

  • Did you check that it has the same result?
  • Interestingly, I did torch._dynamo.explain on the (FP8) decoder and it said no graph breaks.
  • Would be nice to also have this for FP8

@daanelson
Collaborator

@yorickvP interesting that there are no graph breaks; I wonder if the compiler is smart enough to realize that the if statements will always evaluate to true/false depending on loop iteration.

agreed that we should get perf numbers & test that outputs are unchanged

@technillogue
Contributor Author

unfortunately this changes the state_dict keys:

Got 180 missing keys:
encoder.down.0.norm1.weight
encoder.down.0.norm1.bias
encoder.down.0.conv1.weight
encoder.down.0.conv1.bias

encoder.down.0.norm2.weight encoder.down.0.norm2.bias encoder.down.0.conv2.weight encoder.down.0.conv2.bias encoder.down.1.norm1.weight encoder.down.1.norm1.bias encoder.down.1.conv1.weight encoder.down.1.conv1.bias encoder.down.1.norm2.weight encoder.down.1.norm2.bias encoder.down.1.conv2.weight encoder.down.1.conv2.bias encoder.down.2.conv.weight encoder.down.2.conv.bias encoder.down.3.norm1.weight encoder.down.3.norm1.bias encoder.down.3.conv1.weight encoder.down.3.conv1.bias encoder.down.3.norm2.weight encoder.down.3.norm2.bias encoder.down.3.conv2.weight encoder.down.3.conv2.bias encoder.down.3.nin_shortcut.weight encoder.down.3.nin_shortcut.bias encoder.down.4.norm1.weight encoder.down.4.norm1.bias encoder.down.4.conv1.weight encoder.down.4.conv1.bias encoder.down.4.norm2.weight encoder.down.4.norm2.bias encoder.down.4.conv2.weight encoder.down.4.conv2.bias encoder.down.5.conv.weight encoder.down.5.conv.bias encoder.down.6.norm1.weight encoder.down.6.norm1.bias encoder.down.6.conv1.weight encoder.down.6.conv1.bias encoder.down.6.norm2.weight encoder.down.6.norm2.bias encoder.down.6.conv2.weight encoder.down.6.conv2.bias encoder.down.6.nin_shortcut.weight encoder.down.6.nin_shortcut.bias encoder.down.7.norm1.weight encoder.down.7.norm1.bias encoder.down.7.conv1.weight encoder.down.7.conv1.bias encoder.down.7.norm2.weight encoder.down.7.norm2.bias encoder.down.7.conv2.weight encoder.down.7.conv2.bias encoder.down.8.conv.weight encoder.down.8.conv.bias encoder.down.9.norm1.weight encoder.down.9.norm1.bias encoder.down.9.conv1.weight encoder.down.9.conv1.bias encoder.down.9.norm2.weight encoder.down.9.norm2.bias encoder.down.9.conv2.weight encoder.down.9.conv2.bias encoder.down.10.norm1.weight encoder.down.10.norm1.bias encoder.down.10.conv1.weight encoder.down.10.conv1.bias encoder.down.10.norm2.weight encoder.down.10.norm2.bias encoder.down.10.conv2.weight encoder.down.10.conv2.bias decoder.up.0.norm1.weight decoder.up.0.norm1.bias decoder.up.0.conv1.weight 
decoder.up.0.conv1.bias decoder.up.0.norm2.weight decoder.up.0.norm2.bias decoder.up.0.conv2.weight decoder.up.0.conv2.bias decoder.up.0.nin_shortcut.weight decoder.up.0.nin_shortcut.bias decoder.up.1.norm1.weight decoder.up.1.norm1.bias decoder.up.1.conv1.weight decoder.up.1.conv1.bias decoder.up.1.norm2.weight decoder.up.1.norm2.bias decoder.up.1.conv2.weight decoder.up.1.conv2.bias decoder.up.2.norm1.weight decoder.up.2.norm1.bias decoder.up.2.conv1.weight decoder.up.2.conv1.bias decoder.up.2.norm2.weight decoder.up.2.norm2.bias decoder.up.2.conv2.weight decoder.up.2.conv2.bias decoder.up.3.norm1.weight decoder.up.3.norm1.bias decoder.up.3.conv1.weight decoder.up.3.conv1.bias decoder.up.3.norm2.weight decoder.up.3.norm2.bias decoder.up.3.conv2.weight decoder.up.3.conv2.bias decoder.up.3.nin_shortcut.weight decoder.up.3.nin_shortcut.bias decoder.up.4.norm1.weight decoder.up.4.norm1.bias decoder.up.4.conv1.weight decoder.up.4.conv1.bias decoder.up.4.norm2.weight decoder.up.4.norm2.bias decoder.up.4.conv2.weight decoder.up.4.conv2.bias decoder.up.5.norm1.weight decoder.up.5.norm1.bias decoder.up.5.conv1.weight decoder.up.5.conv1.bias decoder.up.5.norm2.weight decoder.up.5.norm2.bias decoder.up.5.conv2.weight decoder.up.5.conv2.bias decoder.up.6.conv.weight decoder.up.6.conv.bias decoder.up.7.norm1.weight decoder.up.7.norm1.bias decoder.up.7.conv1.weight decoder.up.7.conv1.bias decoder.up.7.norm2.weight decoder.up.7.norm2.bias decoder.up.7.conv2.weight decoder.up.7.conv2.bias decoder.up.8.norm1.weight decoder.up.8.norm1.bias decoder.up.8.conv1.weight decoder.up.8.conv1.bias decoder.up.8.norm2.weight decoder.up.8.norm2.bias decoder.up.8.conv2.weight decoder.up.8.conv2.bias decoder.up.9.norm1.weight decoder.up.9.norm1.bias decoder.up.9.conv1.weight decoder.up.9.conv1.bias decoder.up.9.norm2.weight decoder.up.9.norm2.bias decoder.up.9.conv2.weight decoder.up.9.conv2.bias decoder.up.10.conv.weight decoder.up.10.conv.bias decoder.up.11.norm1.weight 
decoder.up.11.norm1.bias decoder.up.11.conv1.weight decoder.up.11.conv1.bias decoder.up.11.norm2.weight decoder.up.11.norm2.bias decoder.up.11.conv2.weight decoder.up.11.conv2.bias decoder.up.12.norm1.weight decoder.up.12.norm1.bias decoder.up.12.conv1.weight decoder.up.12.conv1.bias decoder.up.12.norm2.weight decoder.up.12.norm2.bias decoder.up.12.conv2.weight decoder.up.12.conv2.bias decoder.up.13.norm1.weight decoder.up.13.norm1.bias decoder.up.13.conv1.weight decoder.up.13.conv1.bias decoder.up.13.norm2.weight decoder.up.13.norm2.bias decoder.up.13.conv2.weight decoder.up.13.conv2.bias decoder.up.14.conv.weight decoder.up.14.conv.bias

Got 180 unexpected keys:
encoder.down.0.block.0.conv1.bias
encoder.down.0.block.0.conv1.weight
encoder.down.0.block.0.conv2.bias
encoder.down.0.block.0.conv2.weight

encoder.down.0.block.0.norm1.bias encoder.down.0.block.0.norm1.weight encoder.down.0.block.0.norm2.bias encoder.down.0.block.0.norm2.weight encoder.down.0.block.1.conv1.bias encoder.down.0.block.1.conv1.weight encoder.down.0.block.1.conv2.bias encoder.down.0.block.1.conv2.weight encoder.down.0.block.1.norm1.bias encoder.down.0.block.1.norm1.weight encoder.down.0.block.1.norm2.bias encoder.down.0.block.1.norm2.weight encoder.down.0.downsample.conv.bias encoder.down.0.downsample.conv.weight encoder.down.1.block.0.conv1.bias encoder.down.1.block.0.conv1.weight encoder.down.1.block.0.conv2.bias encoder.down.1.block.0.conv2.weight encoder.down.1.block.0.nin_shortcut.bias encoder.down.1.block.0.nin_shortcut.weight encoder.down.1.block.0.norm1.bias encoder.down.1.block.0.norm1.weight encoder.down.1.block.0.norm2.bias encoder.down.1.block.0.norm2.weight encoder.down.1.block.1.conv1.bias encoder.down.1.block.1.conv1.weight encoder.down.1.block.1.conv2.bias encoder.down.1.block.1.conv2.weight encoder.down.1.block.1.norm1.bias encoder.down.1.block.1.norm1.weight encoder.down.1.block.1.norm2.bias encoder.down.1.block.1.norm2.weight encoder.down.1.downsample.conv.bias encoder.down.1.downsample.conv.weight encoder.down.2.block.0.conv1.bias encoder.down.2.block.0.conv1.weight encoder.down.2.block.0.conv2.bias encoder.down.2.block.0.conv2.weight encoder.down.2.block.0.nin_shortcut.bias encoder.down.2.block.0.nin_shortcut.weight encoder.down.2.block.0.norm1.bias encoder.down.2.block.0.norm1.weight encoder.down.2.block.0.norm2.bias encoder.down.2.block.0.norm2.weight encoder.down.2.block.1.conv1.bias encoder.down.2.block.1.conv1.weight encoder.down.2.block.1.conv2.bias encoder.down.2.block.1.conv2.weight encoder.down.2.block.1.norm1.bias encoder.down.2.block.1.norm1.weight encoder.down.2.block.1.norm2.bias encoder.down.2.block.1.norm2.weight encoder.down.2.downsample.conv.bias encoder.down.2.downsample.conv.weight encoder.down.3.block.0.conv1.bias encoder.down.3.block.0.conv1.weight 
encoder.down.3.block.0.conv2.bias encoder.down.3.block.0.conv2.weight encoder.down.3.block.0.norm1.bias encoder.down.3.block.0.norm1.weight encoder.down.3.block.0.norm2.bias encoder.down.3.block.0.norm2.weight encoder.down.3.block.1.conv1.bias encoder.down.3.block.1.conv1.weight encoder.down.3.block.1.conv2.bias encoder.down.3.block.1.conv2.weight encoder.down.3.block.1.norm1.bias encoder.down.3.block.1.norm1.weight encoder.down.3.block.1.norm2.bias encoder.down.3.block.1.norm2.weight decoder.up.0.block.0.conv1.bias decoder.up.0.block.0.conv1.weight decoder.up.0.block.0.conv2.bias decoder.up.0.block.0.conv2.weight decoder.up.0.block.0.nin_shortcut.bias decoder.up.0.block.0.nin_shortcut.weight decoder.up.0.block.0.norm1.bias decoder.up.0.block.0.norm1.weight decoder.up.0.block.0.norm2.bias decoder.up.0.block.0.norm2.weight decoder.up.0.block.1.conv1.bias decoder.up.0.block.1.conv1.weight decoder.up.0.block.1.conv2.bias decoder.up.0.block.1.conv2.weight decoder.up.0.block.1.norm1.bias decoder.up.0.block.1.norm1.weight decoder.up.0.block.1.norm2.bias decoder.up.0.block.1.norm2.weight decoder.up.0.block.2.conv1.bias decoder.up.0.block.2.conv1.weight decoder.up.0.block.2.conv2.bias decoder.up.0.block.2.conv2.weight decoder.up.0.block.2.norm1.bias decoder.up.0.block.2.norm1.weight decoder.up.0.block.2.norm2.bias decoder.up.0.block.2.norm2.weight decoder.up.1.block.0.conv1.bias decoder.up.1.block.0.conv1.weight decoder.up.1.block.0.conv2.bias decoder.up.1.block.0.conv2.weight decoder.up.1.block.0.nin_shortcut.bias decoder.up.1.block.0.nin_shortcut.weight decoder.up.1.block.0.norm1.bias decoder.up.1.block.0.norm1.weight decoder.up.1.block.0.norm2.bias decoder.up.1.block.0.norm2.weight decoder.up.1.block.1.conv1.bias decoder.up.1.block.1.conv1.weight decoder.up.1.block.1.conv2.bias decoder.up.1.block.1.conv2.weight decoder.up.1.block.1.norm1.bias decoder.up.1.block.1.norm1.weight decoder.up.1.block.1.norm2.bias decoder.up.1.block.1.norm2.weight 
decoder.up.1.block.2.conv1.bias decoder.up.1.block.2.conv1.weight decoder.up.1.block.2.conv2.bias decoder.up.1.block.2.conv2.weight decoder.up.1.block.2.norm1.bias decoder.up.1.block.2.norm1.weight decoder.up.1.block.2.norm2.bias decoder.up.1.block.2.norm2.weight decoder.up.1.upsample.conv.bias decoder.up.1.upsample.conv.weight decoder.up.2.block.0.conv1.bias decoder.up.2.block.0.conv1.weight decoder.up.2.block.0.conv2.bias decoder.up.2.block.0.conv2.weight decoder.up.2.block.0.norm1.bias decoder.up.2.block.0.norm1.weight decoder.up.2.block.0.norm2.bias decoder.up.2.block.0.norm2.weight decoder.up.2.block.1.conv1.bias decoder.up.2.block.1.conv1.weight decoder.up.2.block.1.conv2.bias decoder.up.2.block.1.conv2.weight decoder.up.2.block.1.norm1.bias decoder.up.2.block.1.norm1.weight decoder.up.2.block.1.norm2.bias decoder.up.2.block.1.norm2.weight decoder.up.2.block.2.conv1.bias decoder.up.2.block.2.conv1.weight decoder.up.2.block.2.conv2.bias decoder.up.2.block.2.conv2.weight decoder.up.2.block.2.norm1.bias decoder.up.2.block.2.norm1.weight decoder.up.2.block.2.norm2.bias decoder.up.2.block.2.norm2.weight decoder.up.2.upsample.conv.bias decoder.up.2.upsample.conv.weight decoder.up.3.block.0.conv1.bias decoder.up.3.block.0.conv1.weight decoder.up.3.block.0.conv2.bias decoder.up.3.block.0.conv2.weight decoder.up.3.block.0.norm1.bias decoder.up.3.block.0.norm1.weight decoder.up.3.block.0.norm2.bias decoder.up.3.block.0.norm2.weight decoder.up.3.block.1.conv1.bias decoder.up.3.block.1.conv1.weight decoder.up.3.block.1.conv2.bias decoder.up.3.block.1.conv2.weight decoder.up.3.block.1.norm1.bias decoder.up.3.block.1.norm1.weight decoder.up.3.block.1.norm2.bias decoder.up.3.block.1.norm2.weight decoder.up.3.block.2.conv1.bias decoder.up.3.block.2.conv1.weight decoder.up.3.block.2.conv2.bias decoder.up.3.block.2.conv2.weight decoder.up.3.block.2.norm1.bias decoder.up.3.block.2.norm1.weight decoder.up.3.block.2.norm2.bias decoder.up.3.block.2.norm2.weight 
decoder.up.3.upsample.conv.bias decoder.up.3.upsample.conv.weight

The unexpected keys are the ones in the ae.sft file; the missing keys are the ones the model now expects from using a flat Sequential[ResnetBlock] instead of ModuleList[ModuleList[ResnetBlock]]. there's probably a way to remap the state_dict keys, but it would be easier if we could modify the file.
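
One way such a remapping might look, as a sketch only. The index layout is inferred from the two key lists above, not read from the PR's code, and `remap_key` is a hypothetical helper: in the encoder each level appears to contribute 2 blocks plus a downsample slot; in the decoder level 0 contributes 3 blocks and each later level contributes 3 blocks plus an upsample slot.

```python
import re

# Layout inferred from the key lists in this thread (an assumption):
# encoder: 2 ResnetBlocks per level, a downsample after each non-final level;
# decoder: 3 ResnetBlocks per level, no upsample on level 0, one upsample
# after each of levels 1-3.
_PATTERN = re.compile(
    r"(encoder\.down|decoder\.up)\.(\d+)\.(?:block\.(\d+)|(downsample|upsample))\.(.+)"
)


def remap_key(key: str) -> str:
    """Map a nested checkpoint key to its flat Sequential index (hypothetical helper)."""
    m = _PATTERN.fullmatch(key)
    if m is None:
        return key  # keys outside encoder.down / decoder.up are left untouched
    prefix, level, block, sampler, rest = m.groups()
    level = int(level)
    if prefix == "encoder.down":
        # each encoder level takes 3 slots: 2 blocks + 1 downsample
        start, n_blocks = level * 3, 2
    else:
        # decoder level 0 takes slots 0-2; each later level takes 4 slots
        start, n_blocks = (0 if level == 0 else 3 + (level - 1) * 4), 3
    # blocks sit at start + j; the down/upsample sits right after the blocks
    index = start + (int(block) if sampler is None else n_blocks)
    return f"{prefix}.{index}.{rest}"


def remap_state_dict(state_dict: dict) -> dict:
    """Apply remap_key to every key, keeping the values as they are."""
    return {remap_key(k): v for k, v in state_dict.items()}
```

For example, `remap_key("encoder.down.0.downsample.conv.weight")` gives `encoder.down.2.conv.weight`, which is one of the missing keys listed above.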

@technillogue technillogue force-pushed the syl/unroll-encoder-loops branch 5 times, most recently from 92c664c to 35b343d Compare October 15, 2024 23:54
@technillogue technillogue force-pushed the syl/unroll-encoder-loops branch from 35b343d to 7333f34 Compare October 15, 2024 23:58
@technillogue technillogue mentioned this pull request Nov 8, 2024
@yorickvP
Contributor

yorickvP commented Dec 3, 2024

It might be nicer to instead override load_state_dict
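
A rough sketch of what overriding `load_state_dict` could look like, using a toy module rather than the real autoencoder. `FlatEncoder`, its `_remap` rule, and the two-blocks-per-level layout are assumptions made for illustration; only `nn.Module.load_state_dict` itself is real PyTorch API.

```python
import re

import torch
from torch import nn


class FlatEncoder(nn.Module):
    """Toy stand-in for the flattened model; not the PR's real class."""

    BLOCKS_PER_LEVEL = 2  # hypothetical layout of the nested checkpoint

    def __init__(self):
        super().__init__()
        self.down = nn.Sequential(*[nn.Linear(4, 4) for _ in range(4)])

    def forward(self, x):
        return self.down(x)

    def _remap(self, key):
        # "down.<level>.block.<j>.<rest>" -> "down.<level * 2 + j>.<rest>"
        m = re.fullmatch(r"down\.(\d+)\.block\.(\d+)\.(.+)", key)
        if m is None:
            return key
        level, j, rest = int(m[1]), int(m[2]), m[3]
        return f"down.{level * self.BLOCKS_PER_LEVEL + j}.{rest}"

    def load_state_dict(self, state_dict, *args, **kwargs):
        # Rename old-style keys, then defer to the stock implementation.
        remapped = {self._remap(k): v for k, v in state_dict.items()}
        return super().load_state_dict(remapped, *args, **kwargs)


# A fake checkpoint that uses the old nested key names.
nested_ckpt = {
    f"down.{level}.block.{j}.{name}": torch.full(shape, float(level * 2 + j))
    for level in range(2)
    for j in range(2)
    for name, shape in (("weight", (4, 4)), ("bias", (4,)))
}

model = FlatEncoder()
result = model.load_state_dict(nested_ckpt)  # strict by default
```

One caveat: the override only runs when `load_state_dict` is called on this module directly. When a parent module loads a checkpoint, PyTorch recurses through `_load_from_state_dict` on the children and does not call their `load_state_dict`, so the override would need to live on whichever module the loading code actually calls.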

3 participants