From 4ab4224a96d56a68966a9428bd12819a660dcea5 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Fri, 19 Mar 2021 16:51:22 +0530
Subject: [PATCH 1/2] Fix save pretrained for TPUs

---
 aitextgen/train.py | 51 +++++++++++++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/aitextgen/train.py b/aitextgen/train.py
index f7f6fc4..94be073 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -1,14 +1,17 @@
-import pytorch_lightning as pl
-from pytorch_lightning.callbacks.progress import ProgressBarBase
-from tqdm.auto import tqdm
+import os
+import shutil
+import subprocess
 import sys
+
 import torch
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
-import os
-import shutil
-import subprocess
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks.progress import ProgressBarBase
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 
 
 class ATGTransformer(pl.LightningModule):
@@ -18,12 +21,12 @@ class ATGTransformer(pl.LightningModule):
 
     def __init__(self, model, dataset, hparams, tokenizer):
         super(ATGTransformer, self).__init__()
-        self.model, self.dataset, self.hparams, self.tokenizer = (
+        self.model, self.dataset, self.tokenizer = (
             model,
             dataset,
-            hparams,
             tokenizer,
         )
+        self.save_hyperparameters(hparams)
 
     def forward(self, inputs):
         return self.model(**inputs, return_dict=False)
@@ -112,6 +115,7 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
+        self.save_every_check = self.save_every > 0 and self.steps % self.save_every == 0
 
     def enabled(self):
         self.enabled = True
@@ -172,10 +176,19 @@ def on_batch_end(self, trainer, pl_module):
                 desc += f" — GPU Mem: {gpu_memory} MB"
             self.main_progress_bar.update(self.progress_bar_refresh_rate)
             self.main_progress_bar.set_description(desc)
-
+
+        if _TPU_AVAILABLE and self.save_every_check:
+            did_unfreeze = False
+            if self.enabled:
+                self.unfreeze_layers(pl_module)
+                did_unfreeze = True
+            self.save_pytorch_model(trainer, pl_module, tpu=True)
+            if did_unfreeze:
+                self.freeze_layers(pl_module)
+
         if self.enabled:
             did_unfreeze = False
-            if self.save_every > 0 and self.steps % self.save_every == 0:
+            if not _TPU_AVAILABLE and self.save_every_check:
                 self.unfreeze_layers(pl_module)
                 self.save_pytorch_model(trainer, pl_module)
                 did_unfreeze = True
@@ -228,13 +241,19 @@ def generate_sample_text(self, trainer, pl_module):
 
         self.main_progress_bar.write("=" * 10)
 
-    def save_pytorch_model(self, trainer, pl_module):
-        self.main_progress_bar.write(
-            f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
-        )
-        pl_module.model.save_pretrained(self.output_dir)
+    def save_pytorch_model(self, trainer, pl_module, tpu=False):
+
+        if self.enabled:
+            self.main_progress_bar.write(
+                f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
+            )
+        if tpu:
+            import torch_xla.core.xla_model as xm
+            pl_module.model.save_pretrained(self.output_dir, save_function=xm.save)
+        else:
+            pl_module.model.save_pretrained(self.output_dir)
 
-        if self.save_gdrive:
+        if self.enabled and self.save_gdrive:
             for pt_file in ["pytorch_model.bin", "config.json"]:
                 shutil.copyfile(
                     os.path.join(self.output_dir, pt_file),

From e9ac59826adc31b89d08c63e397f65c20339210a Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Sat, 20 Mar 2021 02:58:32 +0530
Subject: [PATCH 2/2] add property for save_every_check

---
 aitextgen/train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/aitextgen/train.py b/aitextgen/train.py
index 94be073..23baa21 100644
--- a/aitextgen/train.py
+++ b/aitextgen/train.py
@@ -115,7 +115,10 @@ def __init__(
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.train_transformers_only = train_transformers_only
         self.num_layers_freeze = num_layers_freeze
-        self.save_every_check = self.save_every > 0 and self.steps % self.save_every == 0
+
+    @property
+    def save_every_check(self):
+        return self.save_every > 0 and self.steps % self.save_every == 0
 
     def enabled(self):
         self.enabled = True