From 5f24e8c8eaff74b83b81c88f437dba5a58e05f75 Mon Sep 17 00:00:00 2001
From: raci0399
Date: Mon, 6 May 2024 14:24:31 +0200
Subject: [PATCH 1/2] Train the T5 text encoder as well; add options for LoRA
 rank, learning rate and early stopping

---
 train_scripts/train_pixart_lora_hf.py | 221 ++++++++++++++++++++++----
 1 file changed, 190 insertions(+), 31 deletions(-)

diff --git a/train_scripts/train_pixart_lora_hf.py b/train_scripts/train_pixart_lora_hf.py
index f62f0aa..9f6e4b8 100644
--- a/train_scripts/train_pixart_lora_hf.py
+++ b/train_scripts/train_pixart_lora_hf.py
@@ -54,8 +54,19 @@
 logger = get_logger(__name__, log_level="INFO")
 
 
-
-def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+def get_trainable_parameters(optimizer: torch.optim.Optimizer) -> int:
+    return sum(sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups)
+
+def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None, train_text_encoder=False):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -77,6 +88,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str,
     model_card = f"""
 # LoRA text2image fine-tuning - {repo_id}
 These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+LoRA for the text encoder was enabled: {train_text_encoder} \n
 {img_str}
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -272,6 +284,28 @@ def parse_args():
         help="Whether or not to use RS Lora. For more information, see"
         " https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig.use_rslora"
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether or not to also train the text encoder.")
+    parser.add_argument(
+        "--text_encoder_lora_rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices for the text encoder training."),
+    )
+    parser.add_argument(
+        "--text_encoder_learning_rate",
+        type=float,
+        default=None,
+        help="Learning rate for the text encoder; defaults to --learning_rate if not set.",
+    )
+    parser.add_argument(
+        "--text_encoder_stop_at_percentage_steps",
+        type=float,
+        default=1.0,
+        help="The fraction of the total training steps after which training of the text encoder is halted. 1.0 means train it for all steps.",
+    )
     parser.add_argument(
         "--allow_tf32",
         action="store_true",
@@ -502,6 +536,9 @@ def main():
     for param in transformer.parameters():
         param.requires_grad_(False)
 
+    # Move transformer, vae and text_encoder to device and cast to weight_dtype
+    transformer.to(accelerator.device)
+
     lora_config = LoraConfig(
         r=args.rank,
         init_lora_weights="gaussian",
@@ -523,25 +560,49 @@
         use_rslora=args.use_rslora
     )
 
-    # Move transformer, vae and text_encoder to device and cast to weight_dtype
-    transformer.to(accelerator.device)
-
-    def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
-        if not isinstance(model, list):
-            model = [model]
-        for m in model:
-            for param in m.parameters():
-                # only upcast trainable parameters into fp32
-                if param.requires_grad:
-                    param.data = param.to(dtype)
-
     transformer = get_peft_model(transformer, lora_config)
     if args.mixed_precision == "fp16":
         # only upcast trainable parameters (LoRA) into fp32
         cast_training_params(transformer, dtype=torch.float32)
 
+    accelerator.print("Transformer:")
     transformer.print_trainable_parameters()
 
+    if args.train_text_encoder:
+        if not 0 < args.text_encoder_stop_at_percentage_steps <= 1:
+            args.text_encoder_stop_at_percentage_steps = 1
+
+        if args.gradient_checkpointing:
+            # this needs to be done before adding the LoRA layers
+            # otherwise, enabling gradient checkpointing for the text encoder will generate the warning:
+            # "UserWarning: None of the inputs have requires_grad=True. Gradients will be None"
+            # more info:
+            # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2235
+            text_encoder.gradient_checkpointing_enable()
+
+        # prepare the text_encoder for LoRA
+        lora_config_for_text_encoder = LoraConfig(
+            init_lora_weights="gaussian",
+            r=args.text_encoder_lora_rank,
+            # lora_alpha=args. ...,
+            # the dropout probability of the LoRA layers
+            lora_dropout=0.01,
+            target_modules=["k", "q", "v", "o"]
+        )
+
+        text_encoder = get_peft_model(text_encoder, lora_config_for_text_encoder)
+        if args.mixed_precision == "fp16":
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(text_encoder, dtype=torch.float32)
+
+        accelerator.print("\033[91m")
+        accelerator.print("IMPORTANT !! Training the Text Encoder in fp16 might lead to NaNs in step_loss. If it does, please use fp32 or bf16 for training the Text Encoder.")
+        accelerator.print(" more info: \n https://github.com/huggingface/transformers/issues/4586#issuecomment-639704855 \n https://github.com/huggingface/transformers/issues/17978#issuecomment-1173761651")
+        accelerator.print("\033[0m")
+
+        accelerator.print("Text Encoder:")
+        text_encoder.print_trainable_parameters()
+
     # 10. Handle saving and loading of checkpoints
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
@@ -550,9 +611,19 @@ def save_model_hook(models, weights, output_dir):
             if accelerator.is_main_process:
                 transformer_ = accelerator.unwrap_model(transformer)
                 lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")
-                StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora"), lora_state_dict)
+
+                text_encoder_to_save = None
+                if args.train_text_encoder:
+                    text_encoder_ = accelerator.unwrap_model(text_encoder)
+                    text_encoder_to_save = get_peft_model_state_dict(text_encoder_)
+
+                StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora_weights"), lora_state_dict,
+                                                          text_encoder_lora_layers=text_encoder_to_save)
+
                 # save weights in peft format to be able to load them back
-                transformer_.save_pretrained(output_dir)
+                transformer_.save_pretrained(os.path.join(output_dir, "transformer"))
+                if args.train_text_encoder:
+                    text_encoder_.save_pretrained(os.path.join(output_dir, "text_encoder"))
 
                 for _, model in enumerate(models):
                     # make sure to pop weight so that corresponding model is not saved again
@@ -563,6 +634,14 @@ def load_model_hook(models, input_dir):
             transformer_ = accelerator.unwrap_model(transformer)
             transformer_.load_adapter(input_dir, "default", is_trainable=True)
 
+            # raulc0399: it seems that this hook is not implemented!!
+            # check https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py#L701
+            # todo: load and use transformer and text_encoder + cast training params
+            # if args.mixed_precision == "fp16":
+            #     # only upcast trainable parameters (LoRA) into fp32
+            #     cast_training_params(transformer_, dtype=torch.float32)
+            accelerator.print("load_model_hook NOT IMPLEMENTED!!!")
+
             for _ in range(len(models)):
                 # pop models so that they are not loaded again
                 models.pop()
@@ -585,7 +664,24 @@ def load_model_hook(models, input_dir):
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
Make sure it is installed correctly") - lora_layers = filter(lambda p: p.requires_grad, transformer.parameters()) + # transformer params to optimize + params_to_optimize = list(filter(lambda p: p.requires_grad, transformer.parameters())) + params_to_clip = params_to_optimize + + if args.train_text_encoder: + text_encoder_params_to_optimize = list(filter(lambda p: p.requires_grad, text_encoder.parameters())) + params_to_clip = params_to_optimize + text_encoder_params_to_optimize + + # transformer and text encoder have the same learning rate + if args.text_encoder_learning_rate is None: + params_to_optimize = ( + params_to_optimize + text_encoder_params_to_optimize + ) + else: + params_to_optimize = [ + {"params": params_to_optimize, "lr": args.learning_rate}, + {"params": text_encoder_params_to_optimize, "lr": args.text_encoder_learning_rate}, + ] # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -612,13 +708,15 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - lora_layers, + params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) + accelerator.print(f"Total of trainable parameters: {get_trainable_parameters(optimizer):,}") + # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -744,6 +842,9 @@ def collate_fn(examples): transformer, optimizer, train_dataloader, lr_scheduler = \ accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler) + if args.train_text_encoder: + text_encoder = accelerator.prepare(text_encoder) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: @@ -766,6 +867,13 @@ def collate_fn(examples): logger.info(f" Total train batch size (w. 
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+
+    if args.train_text_encoder:
+        training_rate_text_encoder = args.text_encoder_learning_rate if args.text_encoder_learning_rate is not None else f"{args.learning_rate} same as transformer"
+        logger.info(f"  Training text encoder with rank {args.text_encoder_lora_rank}, learning rate {training_rate_text_encoder}")
+        if args.text_encoder_stop_at_percentage_steps < 1:
+            logger.info(f"  Stop training text encoder at {args.text_encoder_stop_at_percentage_steps * 100}% of total training steps")
+
     global_step = 0
     first_epoch = 0
@@ -804,11 +912,16 @@ def collate_fn(examples):
         disable=not accelerator.is_local_main_process,
     )
 
+    models_for_accumulate = [transformer, text_encoder] if args.train_text_encoder else transformer
+
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
+        if args.train_text_encoder:
+            text_encoder.train()
+
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(transformer):
+            with accelerator.accumulate(models_for_accumulate):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * vae.config.scaling_factor
@@ -885,7 +998,6 @@ def collate_fn(examples):
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -923,14 +1035,21 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
-                        transformer_lora_state_dict = get_peft_model_state_dict(unwrapped_transformer)
+                        # raulc0399: not needed since they are already saved in the save_state hook
+                        # unwrapped_transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
+                        # transformer_lora_state_dict = get_peft_model_state_dict(unwrapped_transformer)
 
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=transformer_lora_state_dict,
-                            safe_serialization=True,
-                        )
+                        # text_encoder_to_save = None
+                        # if args.train_text_encoder:
+                        #     text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
+                        #     text_encoder_to_save = get_peft_model_state_dict(text_encoder_)
+
+                        # StableDiffusionPipeline.save_lora_weights(
+                        #     save_directory=os.path.join(save_path, "transformer_lora_weights_checkpoint"),
+                        #     unet_lora_layers=transformer_lora_state_dict,
+                        #     text_encoder_lora_layers=text_encoder_to_save,
+                        #     safe_serialization=True,
+                        # )
 
                         logger.info(f"Saved state to {save_path}")
@@ -940,17 +1059,36 @@ def collate_fn(examples):
             if global_step >= args.max_train_steps:
                 break
 
+        if args.train_text_encoder and args.text_encoder_stop_at_percentage_steps < 1 and global_step >= args.max_train_steps * args.text_encoder_stop_at_percentage_steps:
+            accelerator.print("\033[91mFreezing text encoder...")
+
+            accelerator.print(f"Number of trainable parameters before freeze: {get_trainable_parameters(optimizer):,}")
+
+            text_encoder.zero_grad()
+            text_encoder.requires_grad_(False)
+            args.train_text_encoder = False
+
+            text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
+            text_encoder.save_pretrained(os.path.join(args.output_dir, "text_encoder"))
+
+            accelerator.print(f"Number of trainable parameters after freeze: {get_trainable_parameters(optimizer):,}")
+            accelerator.print("\033[0m")
+
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 logger.info(
                     f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                     f" {args.validation_prompt}."
                 )
+
+                text_encoder_for_generation = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) if args.train_text_encoder else text_encoder
+
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
-                    text_encoder=text_encoder, vae=vae,
+                    text_encoder=text_encoder_for_generation,
+                    vae=vae,
                     torch_dtype=weight_dtype,
                 )
                 pipeline = pipeline.to(accelerator.device)
@@ -986,9 +1124,18 @@ def collate_fn(examples):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
-        transformer.save_pretrained(args.output_dir)
+        transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
         lora_state_dict = get_peft_model_state_dict(transformer)
-        StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "transformer_lora"), lora_state_dict)
+
+        if args.train_text_encoder:
+            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
+            text_encoder_.save_pretrained(os.path.join(args.output_dir, "text_encoder"))
+            text_encoder_to_save = get_peft_model_state_dict(text_encoder_)
+        else:
+            text_encoder_to_save = None
+
+        StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "transformer_lora_weights"), lora_state_dict,
+                                                  text_encoder_lora_layers=text_encoder_to_save)
 
         if args.push_to_hub:
             save_model_card(
@@ -997,6 +1144,7 @@ def collate_fn(examples):
                 base_model=args.pretrained_model_name_or_path,
                 dataset_name=args.dataset_name,
                 repo_folder=args.output_dir,
+                train_text_encoder=args.train_text_encoder,
             )
             upload_folder(
                 repo_id=repo_id,
@@ -1011,7 +1159,15 @@ def collate_fn(examples):
            args.pretrained_model_name_or_path, subfolder='transformer', torch_dtype=weight_dtype
         )
         # load lora weight
-        transformer = PeftModel.from_pretrained(transformer, args.output_dir)
+        transformer = PeftModel.from_pretrained(transformer, os.path.join(args.output_dir, "transformer"))
+
+        if args.train_text_encoder:
+            # Load previous text_encoder
+            text_encoder = T5EncoderModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=weight_dtype
+            )
+            text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(args.output_dir, "text_encoder"))
+
         # Load previous pipeline
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder, vae=vae,
             torch_dtype=weight_dtype,
         )
         pipeline = pipeline.to(accelerator.device)
 
+        if args.train_text_encoder:
+            del text_encoder
+
         del transformer
         torch.cuda.empty_cache()
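Before the second patch: the optimizer layout introduced in PATCH 1/2 - a single AdamW instance driving both LoRA adapters, with an optional second learning rate for the text encoder and a mid-training freeze - can be illustrated outside the training script. The sketch below is only an illustration of that mechanism, not code from the patch: the two `nn.Linear` stand-ins, the learning rates and the step counts are made-up placeholders.

```python
# Toy illustration of the parameter-group and freezing mechanics used in the patch.
# The Linear modules are hypothetical stand-ins, not the PixArt transformer / T5 encoder.
import torch

transformer = torch.nn.Linear(8, 8)   # stand-in for the LoRA-wrapped transformer
text_encoder = torch.nn.Linear(8, 8)  # stand-in for the LoRA-wrapped T5 text encoder

transformer_params = [p for p in transformer.parameters() if p.requires_grad]
text_encoder_params = [p for p in text_encoder.parameters() if p.requires_grad]

# One AdamW instance, one param group per learning rate (same shape as params_to_optimize).
optimizer = torch.optim.AdamW(
    [
        {"params": transformer_params, "lr": 1e-4},   # plays the role of args.learning_rate
        {"params": text_encoder_params, "lr": 5e-6},  # plays the role of args.text_encoder_learning_rate
    ],
    weight_decay=1e-2,
)

# Everything that still requires grad is clipped together, like params_to_clip.
params_to_clip = transformer_params + text_encoder_params

for step in range(100):
    loss = transformer(torch.randn(2, 8)).sum() + text_encoder(torch.randn(2, 8)).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params_to_clip, max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()

    # Early stopping for the text encoder, as controlled by text_encoder_stop_at_percentage_steps:
    # its parameters stay registered in the optimizer, but with requires_grad False they receive
    # no new gradients (AdamW skips parameters whose .grad is None), so only the transformer
    # keeps moving; the clip list is shrunk accordingly.
    if step == 49:
        text_encoder.requires_grad_(False)
        text_encoder.zero_grad(set_to_none=True)
        params_to_clip = [p for p in transformer.parameters() if p.requires_grad]
```

Dropping the frozen parameters from the clip list (and from the objects passed to `accelerator.accumulate`) is exactly what PATCH 2/2 below adds on top of PATCH 1/2.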
From 6bbd7295dd9e290ca1540d8cb39d69f3fa6a6089 Mon Sep 17 00:00:00 2001
From: raci0399
Date: Mon, 6 May 2024 14:47:29 +0200
Subject: [PATCH 2/2] Remove text encoder from params to clip and models to
 accumulate after its training has stopped

---
 train_scripts/train_pixart_lora_hf.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/train_scripts/train_pixart_lora_hf.py b/train_scripts/train_pixart_lora_hf.py
index 9f6e4b8..a5ab797 100644
--- a/train_scripts/train_pixart_lora_hf.py
+++ b/train_scripts/train_pixart_lora_hf.py
@@ -1066,6 +1066,9 @@ def collate_fn(examples):
 
             text_encoder.zero_grad()
             text_encoder.requires_grad_(False)
+            params_to_clip = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+            models_for_accumulate = transformer
+
             args.train_text_encoder = False
 
             text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
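After training with the new flags (for example `--train_text_encoder --text_encoder_lora_rank 4 --text_encoder_learning_rate 5e-6 --text_encoder_stop_at_percentage_steps 0.8`; the values here are only examples), the patched script leaves PEFT adapters in `<output_dir>/transformer` and `<output_dir>/text_encoder` (the latter only when the text encoder was trained), plus a combined `transformer_lora_weights` file. The sketch below mirrors the loading code at the end of the script for inference; the `Transformer2DModel` class, the base checkpoint id and the local paths are assumptions made for the example rather than values fixed by the patches.

```python
# Minimal sketch: reload the LoRA adapters saved by the patched script for inference.
# Assumes diffusers' Transformer2DModel is the transformer class used by the script,
# and uses placeholder model ids / paths.
import torch
from diffusers import DiffusionPipeline, Transformer2DModel
from peft import PeftModel
from transformers import T5EncoderModel

base_model = "PixArt-alpha/PixArt-XL-2-512x512"  # placeholder base checkpoint
output_dir = "./pixart-lora-output"              # placeholder --output_dir

# Base transformer plus its LoRA adapter (saved via transformer_.save_pretrained(.../transformer)).
transformer = Transformer2DModel.from_pretrained(base_model, subfolder="transformer", torch_dtype=torch.float16)
transformer = PeftModel.from_pretrained(transformer, f"{output_dir}/transformer")

# Base T5 encoder plus its LoRA adapter, only present when --train_text_encoder was used.
text_encoder = T5EncoderModel.from_pretrained(base_model, subfolder="text_encoder", torch_dtype=torch.float16)
text_encoder = PeftModel.from_pretrained(text_encoder, f"{output_dir}/text_encoder")

pipeline = DiffusionPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline(prompt="a photo of a robot reading a book").images[0]
image.save("sample.png")
```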