
failed: checkpoint has shape (1, 1, 1, 512) which is incompatible with the model shape (512,) #99

Closed
Mahyar-Ali opened this issue Oct 30, 2021 · 11 comments · Fixed by patil-suraj/vqgan-jax#4



Mahyar-Ali commented Oct 30, 2021

When trying to load the VQModel using from_pretrained, it fails with the error below.

# make sure we use compatible versions
VQGAN_REPO = 'flax-community/vqgan_f16_16384'
VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'

# set up VQGAN
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
DEBUG:filelock:Attempting to acquire lock 139630901715792 on /root/.cache/huggingface/transformers/9d51ab91692e9c42f82e628f71bc27d13685dba2b0b28841dd1fb163e861cb4f.de091ef3cdb74c7d4cc2da0510c99e9ae385befed9ae4473a3191b6d93da9edd.lock
DEBUG:filelock:Lock 139630901715792 acquired on /root/.cache/huggingface/transformers/9d51ab91692e9c42f82e628f71bc27d13685dba2b0b28841dd1fb163e861cb4f.de091ef3cdb74c7d4cc2da0510c99e9ae385befed9ae4473a3191b6d93da9edd.lock
Downloading: 100%
433/433 [00:00<00:00, 16.0kB/s]
DEBUG:filelock:Attempting to release lock 139630901715792 on /root/.cache/huggingface/transformers/9d51ab91692e9c42f82e628f71bc27d13685dba2b0b28841dd1fb163e861cb4f.de091ef3cdb74c7d4cc2da0510c99e9ae385befed9ae4473a3191b6d93da9edd.lock
DEBUG:filelock:Lock 139630901715792 released on /root/.cache/huggingface/transformers/9d51ab91692e9c42f82e628f71bc27d13685dba2b0b28841dd1fb163e861cb4f.de091ef3cdb74c7d4cc2da0510c99e9ae385befed9ae4473a3191b6d93da9edd.lock
DEBUG:filelock:Attempting to acquire lock 139630891372880 on /root/.cache/huggingface/transformers/98190fe7878f67d122ca4539eb6b459bc77dd757fe54e8cd774952c7f32bab79.8efdbd1ba9de17901e4252a40a48003335296da3f7584ca3cac46bff5d9d142b.lock
DEBUG:filelock:Lock 139630891372880 acquired on /root/.cache/huggingface/transformers/98190fe7878f67d122ca4539eb6b459bc77dd757fe54e8cd774952c7f32bab79.8efdbd1ba9de17901e4252a40a48003335296da3f7584ca3cac46bff5d9d142b.lock
Downloading: 100%
290M/290M [00:05<00:00, 58.7MB/s]
DEBUG:filelock:Attempting to release lock 139630891372880 on /root/.cache/huggingface/transformers/98190fe7878f67d122ca4539eb6b459bc77dd757fe54e8cd774952c7f32bab79.8efdbd1ba9de17901e4252a40a48003335296da3f7584ca3cac46bff5d9d142b.lock
DEBUG:filelock:Lock 139630891372880 released on /root/.cache/huggingface/transformers/98190fe7878f67d122ca4539eb6b459bc77dd757fe54e8cd774952c7f32bab79.8efdbd1ba9de17901e4252a40a48003335296da3f7584ca3cac46bff5d9d142b.lock
Working with z of shape (1, 256, 16, 16) = 65536 dimensions.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-00923b8dba5a> in <module>()
      1 # set up VQGAN
----> 2 vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)

/usr/local/lib/python3.7/dist-packages/transformers/modeling_flax_utils.py in from_pretrained(cls, pretrained_model_name_or_path, dtype, *model_args, **kwargs)
    402                 else:
    403                     raise ValueError(
--> 404                         f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
    405                         f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
    406                         "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "

ValueError: Trying to load the pretrained weight for ('decoder', 'mid', 'attn_1', 'norm', 'bias') failed: checkpoint has shape (1, 1, 1, 512) which is incompatible with the model shape (512,). Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this model.
@javismiles

Good day, we ran into this same bug today.
Everything was working great until your last update a few days ago.
There is a way to avoid halting the execution by going into the code and adding "ignore_mismatched_sizes=True" to the call. However, this does not fix the problem: execution continues, but the results produced by the model are terrible, all washed out and with the wrong colors and contrast.
We definitely need a fix for this bug, thank you very much :) Until your latest changes the model was working great and producing beautiful results. We hope you can fix this soon, thank you.
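
Using the names from the snippet in the original report, the workaround looks like this. Note that it only suppresses the error: as the transformers source quoted further down shows, mismatched weights are replaced with randomly initialized ones, which explains the washed-out outputs.

# Workaround only, not a fix: the mismatched norm weights get randomly
# re-initialized, hence the degraded colors and contrast.
vqgan = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, ignore_mismatched_sizes=True
)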


borisdayma commented Oct 30, 2021

I'm not really sure where this comes from.
This model is from https://github.com/patil-suraj/vqgan-jax, which does not appear to have any recent updates, but maybe you could report it there?
Otherwise it could come from https://github.com/huggingface/transformers, so I would check their issues and recent changes. You could also try installing a previous version of transformers.

@javismiles

@borisdayma @Mahyar-Ali
Thank you Boris, in fact this is coming from here:

https://huggingface.co/transformers/_modules/transformers/modeling_flax_utils.html

in this area:
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint
# that have a shape not matching the weights in the model.
mismatched_keys = []
for key in state.keys():
    if key in random_state and state[key].shape != random_state[key].shape:
        if ignore_mismatched_sizes:
            mismatched_keys.append((key, state[key].shape, random_state[key].shape))
            state[key] = random_state[key]
        else:
            raise ValueError(
                f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                "model."
            )

@adesgautam

Still getting the same issue here. Hoping for a fix.
BTW thanks for the awesome repo.

@borisdayma
Owner

There's a suggested fix in the linked huggingface/transformers issue, but for the real fix someone needs to find the flax commit where it broke.

jonathanfrawley added a commit to jonathanfrawley/dalle-mini that referenced this issue Nov 8, 2021

patil-suraj commented Nov 8, 2021

The issue is with the GroupNorm layer. PR google/flax#1553 introduced a refactoring of the normalization layers in flax that changes the feature_shape of the GroupNorm params. (Is this intentional, @jheek?)

To fix this, we could set the minimum flax version to >=0.3.6 and update the weights on the hub to fix the shape of the norm parameters. Users who still want to use the previous flax version can still download the old weights by specifying the commit id, since the hub is git-based.
wdyt?
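
A minimal sketch of the shape change (the NHWC input and channel count here are assumptions chosen to match the traceback above):

# GroupNorm param shapes before/after the flax normalization refactor.
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 16, 16, 512))  # NHWC activations, 512 channels
variables = nn.GroupNorm(num_groups=32).init(jax.random.PRNGKey(0), x)

# flax >= 0.3.6 prints (512,); flax 0.3.5 and earlier produced the
# broadcastable shape (1, 1, 1, 512) stored in the old checkpoints.
print(variables["params"]["scale"].shape)
print(variables["params"]["bias"].shape)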

@borisdayma
Owner

For people having this issue, a temporary fix is to install a previous version of flax: pip install flax==0.3.5


jheek commented Nov 8, 2021

The issue is with the GroupNorm layer. PR google/flax#1553 introduced a refactoring of the normalization layers in flax that changes the feature_shape of the GroupNorm params. (Is this intentional, @jheek?)

I didn't realize this when doing the refactor, but it's a feature, not a bug. The shape is now consistent with the other normalization layers and avoids unnecessarily constraining the input rank for GroupNorm layers.

I'm sorry for the checkpoint inconsistency, those are very annoying. This particular one can usually be resolved with a jax.tree_map(jnp.squeeze, params) btw.
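
A sketch of that one-liner applied to an old-style parameter tree (the nested keys here are illustrative, taken from the error message):

import jax
import jax.numpy as jnp

# Illustrative old-style params with broadcast-shaped norm weights.
old_params = {"decoder": {"mid": {"attn_1": {"norm": {
    "bias": jnp.zeros((1, 1, 1, 512)),
    "scale": jnp.ones((1, 1, 1, 512)),
}}}}}

# Squeeze every leaf, dropping all size-1 dims: (1, 1, 1, 512) -> (512,)
new_params = jax.tree_map(jnp.squeeze, old_params)

# Caveat (hence the "usually"): this squeezes size-1 dims in *every* leaf,
# so weights that legitimately contain size-1 axes would need special handling.
print(new_params["decoder"]["mid"]["attn_1"]["norm"]["bias"].shape)  # (512,)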

@borisdayma
Owner

Thanks @jheek for the quick answer! No problem, we can update our checkpoint.

@patil-suraj , do we need to update anything in your repo or can I just recreate a new checkpoint with the most recent flax version?

@patil-suraj
Collaborator

@patil-suraj , do we need to update anything in your repo or can I just recreate a new checkpoint with the most recent flax version?

Yes, need to update the flax version in setup.py and fix the conversion script. Will take care of it.

@borisdayma
Owner

This has been fixed; you can use dalle-mini/vqgan_imagenet_f16_16384, which is the updated checkpoint.

It's different from the original we use in the inference notebook because that one was fine-tuned on other images.
We will update the inference notebook to use this new checkpoint once we have a new model compatible with it.
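
With the updated checkpoint, loading should work on a recent flax without any workaround (a minimal sketch; the import path assumes the VQModel class from patil-suraj/vqgan-jax):

# Assumes VQModel from https://github.com/patil-suraj/vqgan-jax
# and flax >= 0.3.6.
from vqgan_jax.modeling_flax_vqgan import VQModel

vqgan = VQModel.from_pretrained("dalle-mini/vqgan_imagenet_f16_16384")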
