Use VE's patched load_from_state_dict on TPU for loading empty weights #386

Merged 1 commit on Jul 6, 2023
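On the TPU backend, accelerate is not installed, so accelerate.init_empty_weights() cannot be used to build the model without allocating real storage. This PR falls back to VE's approach instead: patch the __init__ of Linear, Embedding, and LayerNorm so their tensors are created on the meta device, and patch Module._load_from_state_dict so that loading replaces each meta parameter with the checkpoint tensor rather than copy_()-ing into it (which fails on meta tensors). Below is a minimal, self-contained sketch of that idea; it is illustrative only, not the PR's code, and names like _meta_linear_init are made up for the example.

import torch
import torch.nn as nn

# Patch nn.Linear so new layers put their weights on the meta device
# (shape/dtype only, no real storage) -- the same trick the PR applies
# to Linear, Embedding, and LayerNorm.
_old_linear_init = nn.Linear.__init__

def _meta_linear_init(self, *args, device=None, **kwargs):
    return _old_linear_init(self, *args, device="meta", **kwargs)

nn.Linear.__init__ = _meta_linear_init
try:
    model = nn.Linear(4, 4)  # parameters live on the meta device
finally:
    nn.Linear.__init__ = _old_linear_init  # restore, as the context manager does

checkpoint = nn.Linear(4, 4).state_dict()  # stand-in for a real checkpoint

# Materialize the model: swap each meta parameter for the checkpoint tensor
# instead of copying into it, mirroring the patched _load_from_state_dict.
with torch.no_grad():
    for name, tensor in checkpoint.items():
        old = model._parameters[name]
        model._parameters[name] = nn.Parameter(tensor, requires_grad=old.requires_grad)

print(model.weight.device)  # cpu -- the weights are now real tensors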
107 changes: 103 additions & 4 deletions modeling/lazy_loader.py
@@ -56,7 +56,6 @@
import _codecs
import os
from typing import Any, Callable, Dict, Optional, Tuple, Type
import accelerate

from torch.nn import Module
from torch.storage import UntypedStorage
@@ -70,6 +69,12 @@
except ModuleNotFoundError:
HAS_SAFETENSORS = False

try:
import accelerate
USE_TPU_EMPTY_MODULE_METHOD = False
except ModuleNotFoundError:
USE_TPU_EMPTY_MODULE_METHOD = True

import utils
from logger import logger

@@ -400,6 +405,72 @@ def new_pickle_load(*args, **kwargs):
pickle.Unpickler = old_unpickler
pickle.load = old_pickle_load

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}

for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append('While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'
.format(key, type(input_param)))
continue

# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks;
# in such a case, it would error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]

if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
try:
with torch.no_grad():
#param.copy_(input_param)
new_param = torch.nn.Parameter(input_param, requires_grad=param.requires_grad) # This line is new
if name in self._parameters: # This line is new
self._parameters[name] = new_param # This line is new
if name in persistent_buffers: # This line is new
self._buffers[name] = new_param # This line is new
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'
.format(key, param.size(), input_param.size(), ex.args))
elif strict:
missing_keys.append(key)

extra_state_key = prefix + "_extra_state"
if hasattr(Module, "set_extra_state") and getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: # if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)

if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)

@contextlib.contextmanager
def use_lazy_load(
Expand Down Expand Up @@ -453,8 +524,30 @@ def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
patch_safetensors(callback)

if dematerialized_modules:
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
# Most devices can just use Accelerate's implementation, but Transformers on
# the TPU complains about emptied weights unless we use VE's custom patches.
if not USE_TPU_EMPTY_MODULE_METHOD:
init_empty_weights = accelerate.init_empty_weights()
init_empty_weights.__enter__()
else:
old_linear_init = torch.nn.Linear.__init__
old_embedding_init = torch.nn.Embedding.__init__
old_layernorm_init = torch.nn.LayerNorm.__init__

def linear_init(self, *args, device=None, **kwargs):
return old_linear_init(self, *args, device="meta", **kwargs)

def embedding_init(self, *args, device=None, **kwargs):
return old_embedding_init(self, *args, device="meta", **kwargs)

def layernorm_init(self, *args, device=None, **kwargs):
return old_layernorm_init(self, *args, device="meta", **kwargs)

torch.nn.Linear.__init__ = linear_init
torch.nn.Embedding.__init__ = embedding_init
torch.nn.LayerNorm.__init__ = layernorm_init
old_load_from_state_dict = torch.nn.Module._load_from_state_dict
torch.nn.Module._load_from_state_dict = _load_from_state_dict

with use_custom_unpickler(_LazyUnpickler):
yield True
@@ -469,7 +562,13 @@ def torch_load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
)

if dematerialized_modules:
init_empty_weights.__exit__(None, None, None)
if not USE_TPU_EMPTY_MODULE_METHOD:
init_empty_weights.__exit__(None, None, None)
else:
torch.nn.Linear.__init__ = old_linear_init
torch.nn.Embedding.__init__ = old_embedding_init
torch.nn.LayerNorm.__init__ = old_layernorm_init
torch.nn.Module._load_from_state_dict = old_load_from_state_dict


def post_load_cleanup() -> None:
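For context, roughly how this path gets exercised on TPU: a hypothetical usage sketch only, since the full signature of use_lazy_load and the surrounding model-loading code live elsewhere in the repo and are not part of this diff (the dematerialized_modules argument and the import path are assumptions based on the hunks above).

from transformers import AutoModelForCausalLM

from modeling import lazy_loader

# With accelerate absent (the TPU case), dematerialized_modules=True now
# installs the meta-device __init__ patches and the patched
# _load_from_state_dict instead of accelerate.init_empty_weights().
with lazy_loader.use_lazy_load(dematerialized_modules=True):
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")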
3 changes: 1 addition & 2 deletions requirements_mtj.txt
@@ -33,5 +33,4 @@ flask_compress
ijson
ftfy
pydub
sentencepiece
accelerate==0.18.0
sentencepiece