Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configure arbitrary frozen modules via config #869

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def __init__(self):
which can be found here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--model.frozen_modules",
type=string_list,
nargs="+",
default=[],
help="""
Comma separated list of modules's FQN to be frozen inside the model. For example:
--model.frozen_modules=`tok_embeddings,layers.0.attention`
""",
)

# optimizer configs
self.parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion torchtitan/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ def init_distributed(job_config):
)


def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False, exclude_frozen: bool = False) -> int:
num_params = sum(p.numel() for p in model.parameters())
if exclude_frozen:
num_params -= sum(p.numel() for p in model.parameters() if not p.requires_grad)
if exclude_embedding:
num_params -= sum(p.numel() for p in model.tok_embeddings.parameters())
return num_params
Expand Down
24 changes: 23 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from operator import attrgetter
import os
import time
from datetime import timedelta
Expand Down Expand Up @@ -105,21 +106,42 @@ def main(job_config: JobConfig):
)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)
for module_name in job_config.model.frozen_modules:
try:
module = attrgetter(module_name)(model)
num_params = 0
for param in module.parameters():
param.requires_grad = False
num_params += param.numel()
logger.info(
f"{color.red}Freezing {num_params:,} parameters in {color.magenta}model.{module_name}.{color.reset}"
)
except AttributeError:
logger.warning(
f"""Module {color.magenta}{module_name}{color.reset} is set to be frozen but it does not exist in the model.
Make sure the module name is a valid submodule like `tok_embeddings, layers.0.attention` etc.
"""
)

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# log model size
model_param_count = utils.get_num_params(model)
model_trainable_param_count = utils.get_num_params(model, exclude_frozen=True)
if model_trainable_param_count != model_param_count:
trainable_log_str = f"{color.green}of which {model_trainable_param_count:,} parameters are trainable"
else:
trainable_log_str = f"{color.green}all are trainable."
num_flop_per_token = utils.get_num_flop_per_token(
utils.get_num_params(model, exclude_embedding=True),
model_config,
job_config.training.seq_len,
)
logger.info(
f"{color.blue}Model {train_spec.name} {job_config.model.flavor} "
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
f"{color.red}size: {model_param_count:,} total parameters {trainable_log_str}{color.reset}"
)

# loss function to be shared by Pipeline Parallel and SPMD training
Expand Down
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"
# frozen_modules = "tok_embeddings,layers.0.attention"

[optimizer]
name = "AdamW"
Expand Down