🚨🚨🚨 Enforce single model initialization #21431

Merged
sgugger merged 14 commits into main from init_fixes on Feb 9, 2023

Conversation

@sgugger
Collaborator

@sgugger sgugger commented Feb 2, 2023

What does this PR do?

There are currently three problems with the model inits:

Problem 1: When not using the fast init (in practice, when using the model constructor or AutoXxx.from_config instead of from_pretrained), weights are initialized multiple times. @stas00 showed the example of OPTForCausalLM, where post_init() is called three times: in OPTForCausalLM, OPTModel and OPTDecoder. Each of those calls launches a recursive call of _init_weights over all submodules of the model, so the weights end up initialized three times.
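A minimal, torch-free sketch of this triple init: the class names mirror the OPT hierarchy, but `Module`, `apply` and `post_init` here are simplified stand-ins for the real `nn.Module` / `PreTrainedModel` machinery.

```python
# Hypothetical reproduction of problem 1: each wrapper class calls
# post_init(), and each call recursively re-initializes every submodule.
init_counts = {}

def init_weights(module):
    # Stand-in for _init_weights: just count how often each module is hit.
    name = type(module).__name__
    init_counts[name] = init_counts.get(name, 0) + 1

class Module:
    def __init__(self):
        self.submodules = []

    def apply(self, fn):
        # Mimics nn.Module.apply: recurse into children, then visit self.
        for child in self.submodules:
            child.apply(fn)
        fn(self)

    def post_init(self):
        self.apply(init_weights)

class OPTDecoder(Module):
    def __init__(self):
        super().__init__()
        self.post_init()  # init #1 for the decoder weights

class OPTModel(Module):
    def __init__(self):
        super().__init__()
        self.submodules.append(OPTDecoder())
        self.post_init()  # re-inits the decoder (init #2)

class OPTForCausalLM(Module):
    def __init__(self):
        super().__init__()
        self.submodules.append(OPTModel())
        self.post_init()  # re-inits everything below (init #3)

OPTForCausalLM()
print(init_counts)  # {'OPTDecoder': 3, 'OPTModel': 2, 'OPTForCausalLM': 1}
```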

Problem 2: The fast init (of the random weights of the head in from_pretrained) and the non-fast init (as above) are not always equivalent. This is because in from_pretrained, _init_weights is only called on the leaf modules whose weights are not present in the checkpoint, but sometimes _init_weights contains class checks for bigger modules (here is one example in OneFormer).

Problem 3: Some models have a _init_weights function that initializes the same weights in two different ways. Take the OneFormer example again: it initializes a weight that is a Conv2d, but _init_weights is applied recursively, so that Conv2d will also be initialized here with a different rule.

This PR kills these three birds with one stone by slightly changing the _init_weights machinery: it looks for a private _is_hf_initialized attribute on the module and skips the init if that attribute is present and True. Of course, when a module is initialized, this private attribute is set to True once the initialization is done.
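A plain-Python sketch of that guard, with a simple object standing in for `nn.Module` and `guarded_init` standing in for the wrapper around `_init_weights`; only the attribute name `_is_hf_initialized` is taken from the PR.

```python
# Minimal sketch of the skip-if-already-initialized guard.
calls = []

class Submodule:
    pass

def _init_weights(module):
    # Stand-in for a model's real _init_weights: record each invocation.
    calls.append(type(module).__name__)

def guarded_init(module):
    # Skip any module whose weights were already initialized once.
    if getattr(module, "_is_hf_initialized", False):
        return
    _init_weights(module)
    module._is_hf_initialized = True

m = Submodule()
guarded_init(m)  # initializes and sets the flag
guarded_init(m)  # no-op: the flag short-circuits the call
print(calls)     # ['Submodule'] -- weights touched exactly once
```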

This PR gets the 🚨🚨🚨 sign because it might break users' code if they were relying on the (buggy) init of composite models: if a model has an encoder or backbone that is initialized differently from the rest, the init of that encoder/backbone was previously erased by the bigger model's init.

@sgugger sgugger requested review from LysandreJik and stas00 February 2, 2023 20:38
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Feb 2, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger
Collaborator Author

sgugger commented Feb 3, 2023

@stas00 In initial discussions with @LysandreJik, he mentioned he preferred not having a wrapper. The argument about init-weights code in the wild is a sound one, though, so I showed how it could look with the last two commits.

@LysandreJik
Member

Thanks for the PR, and for showing the two options! I feel like the wrapper is a little bit magical, but would make contributions simpler while reducing the complexity of the code.

I would go with the wrapper, if possible.

elif isinstance(module, OneFormerTransformerDecoder):
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
nn.init.constant_(module.query_input_projection.bias, 0)
module.query_input_projection._is_hf_initialized = True
Contributor

@stas00 stas00 Feb 3, 2023

once all instances of the OneFormerTransformerDecoder submodule have _is_hf_initialized = True, this code would never run, no? _init_weights won't get called on this submodule anymore.

Collaborator Author

The goal with this is to avoid module.query_input_projection being initialized another time: since it is a Conv2d, there is a path for Conv2d in the succession of tests here. This is an example of a fix for problem 2 in the PR description.
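For illustration, a torch-free mock-up of this interaction: the classes and init helpers are hypothetical stand-ins, the traversal order is simplified to visit the parent first, and only the `_is_hf_initialized` flag mirrors the PR.

```python
# Without the flag, the leaf Conv2d would also hit the generic Conv2d
# branch of _init_weights and be re-initialized with a different rule.
inits = []

class Conv2d:
    pass

class TransformerDecoder:
    def __init__(self):
        self.query_input_projection = Conv2d()

def xavier_init(module):
    inits.append(("xavier", module))

def default_init(module):
    inits.append(("default", module))

def _init_weights(module):
    if getattr(module, "_is_hf_initialized", False):
        return  # already handled, skip
    if isinstance(module, TransformerDecoder):
        xavier_init(module.query_input_projection)
        # Flag the projection so the generic Conv2d branch never touches it.
        module.query_input_projection._is_hf_initialized = True
    elif isinstance(module, Conv2d):
        default_init(module)

decoder = TransformerDecoder()
# Traversal order simplified to parent-first for illustration.
for module in (decoder, decoder.query_input_projection):
    _init_weights(module)

print([kind for kind, _ in inits])  # ['xavier'] -- the Conv2d rule never fires
```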

Contributor

@stas00 stas00 Feb 3, 2023

Got it. I can see now why; it's because, as you said, later in this function we have:

elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):

Contributor

@stas00 stas00 Feb 3, 2023

Here is an idea: pass the param key name to _init_weights in addition to the module, where possible? That way one could also branch on the key name and shortcut the "switch case".

Contributor

So instead of "init a Conv2d this way unless the param belongs to this parent module", it becomes "for this param key only, use this init".

Contributor

@stas00 stas00 Feb 3, 2023

But I guess it'd be difficult if the key isn't always fully qualified, depending on how the model was instantiated (the full stack, or say just the decoder). Perhaps just checking the last segment of the param name would give enough context?

Collaborator Author

That would require changing the signature of all _init_weights, so we're back to changing all models ;-) I think it's probably easier this way even if it looks more convoluted at first glance.

Contributor

ah, didn't think of that one! you're correct, Sylvain.

It would have been useful if the init function returned the list of params it touched; then the outside caretaker could do the accounting automatically. But again, this adds more complexity.

Contributor

@stas00 stas00 Feb 3, 2023

This reminds me of the Deepspeed external parameter special case that was originally an issue for the same reason.

So let's please document this special case, using the full example with the two isinstance branches, to show what to do when a submodule inits weights that are outside its immediate descendants.

Collaborator Author

Will add it to the add model doc.

@stas00
Contributor

stas00 commented Feb 3, 2023

Thank you for making it simpler for the end user, Sylvain - I will test this today on m4 and get back to you.

Contributor

@stas00 stas00 left a comment

Tested this on the m4 issue that started this whole investigation.

This PR solves the problem. My init tests that check the expected mean and variance now pass!

Thank you, Sylvain!

Collaborator Author

This leftover in BART clashes with the new logic and testing. It is fixed here and in several copies (and in practice does exactly the same thing, since weights are now only initialized once).

Collaborator Author

Fixes the weird hack for init of those modules (see below).

Collaborator Author

Hard-coded values on a random model, which is now initialized differently since there is only one init rule.

Collaborator Author

Same here.

Collaborator Author

Fixes this to get no init in the subconfigs as well.

@stas00
Contributor

stas00 commented Feb 9, 2023

Thank you for doing a massive adjustment work and the explanations, Sylvain!

This is hard work and very awesome for everybody to benefit from!

@sgugger
Collaborator Author

sgugger commented Feb 9, 2023

Last failing test is flaky so this is good for final review!

Member

@LysandreJik LysandreJik left a comment

Great change! Thanks for spending time on this, it's very nice to be sure that the initialisation is correct across methods and models.

LGTM!

The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
random weights, thus making sure that the `init()` methods of all components work.

Note that all random initialization should happen in the `_init_weights` method of your `BrandNewBertPreTrainedModel`
Member

Love the docs! Thanks for spending time on them, it's worthwhile.

config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

def test_save_load_fast_init_from_base(self):
Member

Amazing to remove the special case

@sgugger sgugger changed the title from "Enforce single model initialization" to "🚨🚨🚨 Enforce single model initialization" on Feb 9, 2023
@sgugger sgugger merged commit 04b2f13 into main Feb 9, 2023
@sgugger sgugger deleted the init_fixes branch February 9, 2023 20:46
@stas00
Contributor

stas00 commented Feb 9, 2023

so it didn't make it into https://github.com/huggingface/transformers/releases/tag/v4.26.1, right?

do you know if you plan another hotfix release in the future or plan to wait for 4.27.0?

Asking as I need to anchor requirements on this fix for m4, where I found this bug.

@sgugger
Collaborator Author

sgugger commented Feb 10, 2023

This won't be until 4.27.0 as it could come with bugs we need to fix (and it's not a regression fix so won't go in a patch).

@stas00
Contributor

stas00 commented Feb 10, 2023

Thank you for the clarity, Sylvain. 4.27.0 it is.
