Configure arbitrary frozen modules via config
lkhphuc committed Feb 20, 2025
1 parent 7a34e3c commit 526513c
Showing 4 changed files with 37 additions and 2 deletions.
10 changes: 10 additions & 0 deletions torchtitan/config_manager.py
@@ -209,6 +209,16 @@ def __init__(self):
which can be found here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--model.frozen_modules",
type=string_list,
nargs="+",
default=[],
help="""
Comma-separated list of module FQNs to freeze inside the model. For example:
--model.frozen_modules=`tok_embeddings,layers.0.attention`
""",
)

# optimizer configs
self.parser.add_argument(
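For reference, the new flag reuses the existing `string_list` type. Below is a minimal sketch of how a comma-separated value like this ends up as a Python list, under the assumption that `string_list` simply splits on commas (the real helper in config_manager.py may differ):

# Sketch only: assumes string_list splits a comma-separated argument into a list.
def string_list(raw_arg: str) -> list[str]:
    return [part.strip() for part in raw_arg.split(",") if part.strip()]

# --model.frozen_modules="tok_embeddings,layers.0.attention"
# -> ["tok_embeddings", "layers.0.attention"]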
4 changes: 3 additions & 1 deletion torchtitan/utils.py
@@ -272,8 +272,10 @@ def init_distributed(job_config):
)


-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False, exclude_frozen: bool = False) -> int:
num_params = sum(p.numel() for p in model.parameters())
if exclude_frozen:
num_params -= sum(p.numel() for p in model.parameters() if not p.requires_grad)
if exclude_embedding:
num_params -= sum(p.numel() for p in model.tok_embeddings.parameters())
return num_params
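As a rough usage sketch (a toy model, not torchtitan's, with `get_num_params` imported from torchtitan.utils), `exclude_frozen=True` subtracts every parameter whose `requires_grad` is False:

import torch

toy = torch.nn.Sequential(
    torch.nn.Linear(4, 8),  # 4*8 + 8 = 40 parameters
    torch.nn.Linear(8, 2),  # 8*2 + 2 = 18 parameters
)
for p in toy[0].parameters():
    p.requires_grad = False  # freeze the first layer

total = get_num_params(toy)                           # 58
trainable = get_num_params(toy, exclude_frozen=True)  # 18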
24 changes: 23 additions & 1 deletion train.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from operator import attrgetter
import os
import time
from datetime import timedelta
@@ -105,21 +106,42 @@ def main(job_config: JobConfig):
)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)
for module_name in job_config.model.frozen_modules:
try:
module = attrgetter(module_name)(model)
num_params = 0
for param in module.parameters():
param.requires_grad = False
num_params += param.numel()
logger.info(
f"{color.red}Freezing {num_params:,} parameters in {color.magenta}model.{module_name}.{color.reset}"
)
except AttributeError:
logger.warning(
f"""Module {color.magenta}{module_name}{color.reset} is set to be frozen but it does not exist in the model.
Make sure the module name is a valid submodule like `tok_embeddings, layers.0.attention` etc.
"""
)
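For reference, a standalone sketch (toy model, not torchtitan code) of what the loop above does: `attrgetter` follows the dotted FQN, which works through an `nn.ModuleList` because its children are registered under string keys such as "0":

from operator import attrgetter
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(16, 8)
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

toy = ToyModel()
frozen = attrgetter("layers.0")(toy)  # resolves toy.layers[0]
for param in frozen.parameters():
    param.requires_grad = False  # frozen: no gradients, no updates

# toy.get_submodule("layers.0") is an equivalent lookup that likewise
# raises AttributeError for an invalid path.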

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# log model size
model_param_count = utils.get_num_params(model)
model_trainable_param_count = utils.get_num_params(model, exclude_frozen=True)
if model_trainable_param_count != model_param_count:
trainable_log_str = f"{color.green}of which {model_trainable_param_count:,} are trainable"
else:
trainable_log_str = f"{color.green}all of which are trainable"
num_flop_per_token = utils.get_num_flop_per_token(
utils.get_num_params(model, exclude_embedding=True),
model_config,
job_config.training.seq_len,
)
logger.info(
f"{color.blue}Model {train_spec.name} {job_config.model.flavor} "
f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
f"{color.red}size: {model_param_count:,} total parameters {trainable_log_str}{color.reset}"
)

# loss function to be shared by Pipeline Parallel and SPMD training
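One downstream effect of freezing, noted here as a sketch of standard PyTorch behavior rather than anything this commit changes: frozen parameters never receive gradients, so an optimizer built over `model.parameters()` simply leaves them untouched. Passing only the trainable parameters is an equivalent option (illustrative learning rate; torchtitan constructs its optimizers elsewhere):

# Reuses `model` and `torch` from the surrounding train.py context.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=8e-4)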
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -27,6 +27,7 @@ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./tests/assets/test_tiktoken.model"
# converters = "float8"
# frozen_modules = "tok_embeddings,layers.0.attention"

[optimizer]
name = "AdamW"
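With the TOML line uncommented, the option sits under the existing [model] section of the config; a minimal illustration (other [model] keys omitted):

[model]
# the same value can also be passed on the command line:
#   --model.frozen_modules="tok_embeddings,layers.0.attention"
frozen_modules = "tok_embeddings,layers.0.attention"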
