Code for using the ImageNet pretrained model #146
I have this problem too. Thanks for the tip!
Hi, sorry for the confusion. The ResNet (nn.Module file) in this repo is only for CIFAR input, i.e., 32x32. The ImageNet weights provided here are for 224x224 input, so they can only be loaded with the official PyTorch definition of ResNet, which takes 224x224 input. Historically, we first released this repo for CIFAR-10/100, so we defined ResNet only for 32x32 input. Later on, I trained a SupCon ImageNet model with my other code base and shared the weights in this repo, which caused this confusion.
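One way to see why the two definitions are not interchangeable is to trace the feature-map size through each network. This is only a rough sketch: the torchvision-style numbers (7x7 stride-2 stem conv followed by a 3x3 stride-2 max-pool) follow the official ResNet definition, while the CIFAR-style stem (a single 3x3 stride-1 conv, no max-pool) is an assumption about how the ResNet in this repo is defined. The function names are made up for illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def imagenet_style_sizes(size):
    """Feature-map sizes after the stem and after layer1..layer4 (torchvision-style)."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)  # max-pool
    sizes = [size]
    for stride in (1, 2, 2, 2):  # layer1 keeps resolution, layer2..4 halve it
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


def cifar_style_sizes(size):
    """Same trace for an assumed CIFAR-style stem: one 3x3 stride-1 conv, no max-pool."""
    size = conv_out(size, kernel=3, stride=1, padding=1)
    sizes = [size]
    for stride in (1, 2, 2, 2):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


print(imagenet_style_sizes(224))  # [56, 56, 28, 14, 7]
print(imagenet_style_sizes(32))   # [8, 8, 4, 2, 1]
print(cifar_style_sizes(32))      # [32, 32, 16, 8, 4]
```

Under these assumptions, the two stems use different first-layer kernels (7x7 versus 3x3), so the first conv weights would have different shapes, which is consistent with the checkpoint only loading into the official definition.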
@HobbitLong that's what I thought as well hahaha. This code was for CIFAR but the weights were for ImageNet.
Glad you figured it out much earlier, and thank you for sharing it!
I thought this would be helpful for other people. I had issues getting the ResNet used in this repo running properly, but the given weights work well with PyTorch's default ResNet.
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Load the checkpoint on the CPU
checkpoint = torch.load("supcon_official.pth", map_location="cpu")
state_dict = checkpoint["model"]

# Strip the "module." and "encoder." prefixes so the keys match torchvision's names
state_dict = {
    k.replace("module.", "").replace("encoder.", ""): v
    for k, v in state_dict.items()
}

# The projection head is not needed
for key in ("head.0.weight", "head.0.bias", "head.2.weight", "head.2.bias"):
    state_dict.pop(key, None)

# Use the standard PyTorch ResNet-50 and replace the classifier with an identity,
# so the model returns the pooled encoder features
model = resnet50()
model.fc = nn.Identity()

# This should do the trick
model.load_state_dict(state_dict, strict=True)
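The key renaming above can also be wrapped in a small helper and checked on a toy dictionary before touching the real checkpoint. This is only a sketch: `clean_supcon_keys` is a hypothetical name, and the example keys are illustrative, not the verified layout of `supcon_official.pth`. Unlike `str.replace`, it strips only leading prefixes, in whichever order they appear.

```python
def clean_supcon_keys(state_dict, prefixes=("module.", "encoder."), drop_prefix="head."):
    """Return a copy with leading wrapper prefixes stripped and head entries dropped."""
    cleaned = {}
    for key, value in state_dict.items():
        # Keep stripping until no known prefix is left at the front,
        # so both "module.encoder.x" and "encoder.module.x" become "x".
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        # Skip the projection head, which the plain encoder does not have.
        if key.startswith(drop_prefix):
            continue
        cleaned[key] = value
    return cleaned


# Toy stand-in for a checkpoint; values would normally be tensors.
example = {
    "module.encoder.conv1.weight": 1,
    "encoder.module.bn1.bias": 2,
    "module.head.0.weight": 3,
    "head.2.bias": 4,
}
cleaned = clean_supcon_keys(example)
print(sorted(cleaned))  # ['bn1.bias', 'conv1.weight']
```

With the real checkpoint, the result of `clean_supcon_keys(checkpoint["model"])` could then be passed to `model.load_state_dict(..., strict=True)`, which raises an error if any key is still missing or unexpected.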