Skip to content

Commit

Permalink
Put test set generation in 'on_validation_end', otherwise it gets ran…
Browse files Browse the repository at this point in the history
… for each validation batch
  • Loading branch information
mitchelldehaven committed Apr 23, 2024
1 parent d8655a1 commit e6af4c6
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions src/python/piper_train/vits/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,29 +282,32 @@ def training_step_d(self, batch: Batch):
def validation_step(self, batch: Batch, batch_idx: int):
val_loss = self.training_step_g(batch) + self.training_step_d(batch)
self.log("val_loss", val_loss)

# Generate audio examples
for utt_idx, test_utt in enumerate(self._test_dataset):
text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
scales = [0.667, 1.0, 0.8]
sid = (
test_utt.speaker_id.to(self.device)
if test_utt.speaker_id is not None
else None
)
test_audio = self(text, text_lengths, scales, sid=sid).detach()

# Scale to make louder in [-1, 1]
test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

tag = test_utt.text or str(utt_idx)
self.logger.experiment.add_audio(
tag, test_audio, sample_rate=self.hparams.sample_rate
)

return val_loss

def on_validation_end(self) -> None:
# Generate audio examples after validation, but not during sanity check
if not self.trainer.sanity_checking:
for utt_idx, test_utt in enumerate(self._test_dataset):
text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
scales = [0.667, 1.0, 0.8]
sid = (
test_utt.speaker_id.to(self.device)
if test_utt.speaker_id is not None
else None
)
test_audio = self(text, text_lengths, scales, sid=sid).detach()

# Scale to make louder in [-1, 1]
test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

tag = test_utt.text or str(utt_idx)
self.logger.experiment.add_audio(
tag, test_audio, sample_rate=self.hparams.sample_rate
)

return super().on_validation_end()

def configure_optimizers(self):
optimizers = [
torch.optim.AdamW(
Expand Down

0 comments on commit e6af4c6

Please sign in to comment.