diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 96ebb839a..cd9468557 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -209,6 +209,16 @@ def __init__(self):
                 which can be found here: https://github.com/pytorch/ao
             """,
         )
+        self.parser.add_argument(
+            "--model.frozen_modules",
+            type=string_list,
+            nargs="+",
+            default=[],
+            help="""
+                Comma-separated list of module FQNs to freeze inside the model. For example:
+                --model.frozen_modules="tok_embeddings,layers.0.attention"
+            """,
+        )
 
         # optimizer configs
         self.parser.add_argument(
diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index d976003d5..c77685e96 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -272,8 +272,10 @@ def init_distributed(job_config):
     )
 
 
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False, exclude_frozen: bool = False) -> int:
     num_params = sum(p.numel() for p in model.parameters())
+    if exclude_frozen:
+        num_params -= sum(p.numel() for p in model.parameters() if not p.requires_grad)
     if exclude_embedding:
         num_params -= sum(p.numel() for p in model.tok_embeddings.parameters())
     return num_params
diff --git a/train.py b/train.py
index ced765087..ad0b9ee23 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from operator import attrgetter
 import os
 import time
 from datetime import timedelta
@@ -105,6 +106,22 @@ def main(job_config: JobConfig):
     )
     with torch.device("meta"):
         model = model_cls.from_model_args(model_config)
+        for module_name in job_config.model.frozen_modules:
+            try:
+                module = attrgetter(module_name)(model)
+                num_params = 0
+                for param in module.parameters():
+                    param.requires_grad = False
+                    num_params += param.numel()
+                logger.info(
+                    f"{color.red}Freezing {num_params:,} parameters in {color.magenta}model.{module_name}{color.reset}"
+                )
+            except AttributeError:
+                logger.warning(
+                    f"""Module {color.magenta}{module_name}{color.reset} is set to be frozen but does not exist in the model.
+                    Make sure the name refers to a valid submodule, e.g. `tok_embeddings` or `layers.0.attention`.
+                    """
+                )
 
     # Build the collection of model converters. No-op if `model.converters` empty
     model_converters = build_model_converters(job_config, parallel_dims)
@@ -112,6 +129,11 @@ def main(job_config: JobConfig):
 
     # log model size
     model_param_count = utils.get_num_params(model)
+    model_trainable_param_count = utils.get_num_params(model, exclude_frozen=True)
+    if model_trainable_param_count != model_param_count:
+        trainable_log_str = f"{color.green}of which {model_trainable_param_count:,} are trainable"
+    else:
+        trainable_log_str = f"{color.green}all of which are trainable"
     num_flop_per_token = utils.get_num_flop_per_token(
         utils.get_num_params(model, exclude_embedding=True),
         model_config,
@@ -119,7 +141,7 @@
     )
     logger.info(
         f"{color.blue}Model {train_spec.name} {job_config.model.flavor} "
-        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
+        f"{color.red}size: {model_param_count:,} total parameters, {trainable_log_str}{color.reset}"
     )
 
     # loss function to be shared by Pipeline Parallel and SPMD training
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 8f4a40dd6..974d25c0e 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -27,6 +27,7 @@ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./tests/assets/test_tiktoken.model"
 # converters = "float8"
+# frozen_modules = "tok_embeddings,layers.0.attention"
 
 [optimizer]
 name = "AdamW"
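
Note (not part of the diff): below is a minimal, standalone sketch of what the attrgetter-based freezing added in train.py does, assuming a hypothetical ToyModel whose submodule names are made up for illustration. It shows how dotted FQNs such as "layers.0" are resolved and how the trainable-parameter count drops.

from operator import attrgetter

import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model; only the submodule names matter for this sketch."""

    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(8, 4)
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
        self.output = nn.Linear(4, 8)


model = ToyModel()
# Roughly equivalent to passing --model.frozen_modules="tok_embeddings,layers.0"
for module_name in ["tok_embeddings", "layers.0"]:
    # attrgetter resolves dotted FQNs, e.g. "layers.0" -> model.layers[0]
    for param in attrgetter(module_name)(model).parameters():
        param.requires_grad = False

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total:,} total parameters, of which {trainable:,} are trainable")

The TOML equivalent is setting frozen_modules under the [model] section of the training config, as the commented-out line added to debug_model.toml above illustrates.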