
Shapes mismatch triggered at modeling_flax_utils #14215

Closed
javismiles opened this issue Oct 30, 2021 · 5 comments

Comments

@javismiles

Good day,
while using the MiniDalle repo at:
borisdayma/dalle-mini#99

we are suddenly getting this error which was not happening before:
"Trying to load the pretrained weight for ('decoder', 'mid', 'attn_1', 'norm', 'bias') failed: checkpoint has shape (1, 1, 1, 512) which is incompatible with the model shape (512,). Using ignore_mismatched_sizes=True if you really want to load this checkpoint inside this model."

This is being triggered here:
https://huggingface.co/transformers/_modules/transformers/modeling_flax_utils.html

in this area:
```python
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint
# that have a shape not matching the weights in the model.
mismatched_keys = []
for key in state.keys():
    if key in random_state and state[key].shape != random_state[key].shape:
        if ignore_mismatched_sizes:
            mismatched_keys.append((key, state[key].shape, random_state[key].shape))
            state[key] = random_state[key]
        else:
            raise ValueError(
                f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                "Using ignore_mismatched_sizes=True if you really want to load this checkpoint inside this "
                "model."
            )
```

Execution can be kept from halting by going into the code and passing `ignore_mismatched_sizes=True` in the call. However, this does not fix the problem: the execution continues, but the results produced by the minidalle model are wrong, all washed out and with the wrong colors and contrast (which was not happening some days ago, so something has changed that is producing this problem).
So this seems to be a bug coming from this file. Any tips are super welcome, thank you :)
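For context on why the workaround degrades the output: a minimal, self-contained sketch of the check quoted above (dictionary-based stand-ins, not the actual transformers internals) shows that with `ignore_mismatched_sizes=True` the checkpoint weight is *discarded* and replaced by the randomly initialized one, which would explain the washed-out results:

```python
import numpy as np

# Hypothetical stand-in for the loading logic in modeling_flax_utils.
def merge_state(state, random_state, ignore_mismatched_sizes=False):
    mismatched_keys = []
    for key in state:
        if key in random_state and state[key].shape != random_state[key].shape:
            if ignore_mismatched_sizes:
                mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                # Checkpoint weight is dropped in favor of the random init,
                # so the layer effectively loses its trained parameters.
                state[key] = random_state[key]
            else:
                raise ValueError(f"shape mismatch for {key}")
    return state, mismatched_keys

ckpt = {"bias": np.zeros((1, 1, 1, 512))}   # shape stored in the checkpoint
rand = {"bias": np.ones((512,))}            # shape the model now expects
state, mismatched = merge_state(ckpt, rand, ignore_mismatched_sizes=True)
print(mismatched)  # [('bias', (1, 1, 1, 512), (512,))]
```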

@LysandreJik
Member

cc @patil-suraj

@borisdayma
Contributor

borisdayma commented Nov 3, 2021

It seems to work with flax==0.3.5.
My guess is that weights are now being squeezed. Maybe need to reupload a new checkpoint?
Actually, here a shape of (512,) seems to make more sense than (1, 1, 1, 512).
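If the change really is that the weights are now being squeezed, the shapes line up exactly: dropping the singleton dimensions of the old checkpoint's bias gives the shape the current model expects. A quick illustration (generic NumPy, not the actual dalle-mini loading code):

```python
import numpy as np

# Old checkpoint stores the norm bias with singleton leading dims ...
old_bias = np.zeros((1, 1, 1, 512))

# ... squeezing removes them, matching the (512,) the model now wants.
new_bias = np.squeeze(old_bias)
print(new_bias.shape)  # (512,)
```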

@patrickvonplaten
Contributor

Do we know which commit in Flax is responsible for this bug?

@HindAB1

HindAB1 commented Nov 8, 2021

> It seems to work with flax==0.3.5. My guess is that weights are now being squeezed. Maybe need to reupload a new checkpoint? Actually here a shape of (512,) seems to make more sense than (1,1,1,512)

Could you please elaborate?
Where should I add flax==0.3.5?

Thanks!
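For anyone with the same question: pinning a dependency version is usually done in the project's environment setup, for example in a requirements file (a generic illustration, not taken from the dalle-mini repo):

```
# requirements.txt -- pin flax to the version reported to work above
flax==0.3.5
```

Reinstalling with this pin (e.g. via `pip install -r requirements.txt`) downgrades flax before the model is loaded.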

jonathanfrawley added a commit to jonathanfrawley/dalle-mini that referenced this issue Nov 8, 2021
@patil-suraj
Contributor

Left a comment here, borisdayma/dalle-mini#99 (comment)
Closing this issue, since it's not related to transformers.

6 participants