FastSpeech2 Code Reading Notes: Model Training


1.model/loss.py

During training, FastSpeech2 optimizes the duration predictor, pitch predictor, and energy predictor jointly with the rest of the network, and, as in the earlier autoregressive models, it computes a loss on the mel-spectrogram both before and after the PostNet. Five losses are therefore computed in total: mel_loss, postnet_mel_loss, duration_loss, pitch_loss, and energy_loss, and their sum is the training objective. loss.py defines the corresponding loss class.

 import torch
 import torch.nn as nn
 
 
 class FastSpeech2Loss(nn.Module):
     """ FastSpeech2 Loss """
 
     # 自定义的损失,整个模型的损失由五个不同损失组成
     # 分别是mel_loss、postnet_mel_loss、duration_loss、pitch_loss、energy_loss
     def __init__(self, preprocess_config, model_config):
         super(FastSpeech2Loss, self).__init__()
         self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
         self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
         self.mse_loss = nn.MSELoss()
         self.mae_loss = nn.L1Loss()
 
     def forward(self, inputs, predictions):
         # 目标,相当于label
         (mel_targets, _, _, pitch_targets, energy_targets, duration_targets,) = inputs[6:]
         # 模型的输出
         (mel_predictions,postnet_mel_predictions, pitch_predictions, energy_predictions, log_duration_predictions,
          _, src_masks, mel_masks, _, _,) = predictions
         src_masks = ~src_masks
         mel_masks = ~mel_masks
         log_duration_targets = torch.log(duration_targets.float() + 1)  # 对目标持续时间取log
         mel_targets = mel_targets[:, : mel_masks.shape[1], :]
         mel_masks = mel_masks[:, :mel_masks.shape[1]]
 
         log_duration_targets.requires_grad = False
         pitch_targets.requires_grad = False
         energy_targets.requires_grad = False
         mel_targets.requires_grad = False
 
         if self.pitch_feature_level == "phoneme_level":
             pitch_predictions = pitch_predictions.masked_select(src_masks)
             pitch_targets = pitch_targets.masked_select(src_masks)
         elif self.pitch_feature_level == "frame_level":
             pitch_predictions = pitch_predictions.masked_select(mel_masks)
             pitch_targets = pitch_targets.masked_select(mel_masks)
 
         if self.energy_feature_level == "phoneme_level":
             energy_predictions = energy_predictions.masked_select(src_masks)
             energy_targets = energy_targets.masked_select(src_masks)
         if self.energy_feature_level == "frame_level":
             energy_predictions = energy_predictions.masked_select(mel_masks)
             energy_targets = energy_targets.masked_select(mel_masks)
 
         log_duration_predictions = log_duration_predictions.masked_select(src_masks)
         log_duration_targets = log_duration_targets.masked_select(src_masks)
 
         mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
         postnet_mel_predictions = postnet_mel_predictions.masked_select(mel_masks.unsqueeze(-1))
         mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
 
         # 5个loss
         # 解码器预测的mel谱图的损失
         mel_loss = self.mae_loss(mel_predictions, mel_targets)
         # 解码器预测的mel谱图经过postnet处理后的损失
         postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
         # pitch loss
         pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
         # energy loss
         energy_loss = self.mse_loss(energy_predictions, energy_targets)
         # duration loss
         duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
 
         total_loss = (
             mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
         )
 
         return (
             total_loss,
             mel_loss,
             postnet_mel_loss,
             pitch_loss,
             energy_loss,
             duration_loss,
         )
 

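As a quick illustration of how the masks above are applied, here is a minimal, self-contained sketch (toy tensors, not real model outputs) showing how masked_select drops the padded positions before the MAE is computed, so padding never contributes to the loss:

 import torch
 import torch.nn as nn

 # Toy batch of 2 sequences with true lengths 2 and 3, padded to length 3.
 lengths = torch.tensor([2, 3])
 # True marks valid positions (same convention as ~src_masks / ~mel_masks above).
 mask = torch.arange(3).unsqueeze(0) < lengths.unsqueeze(1)      # shape [2, 3]

 pred = torch.tensor([[1.0, 2.0, 9.0], [1.0, 2.0, 3.0]])         # 9.0 lies in a padded slot
 target = torch.tensor([[1.5, 2.5, 0.0], [1.0, 2.0, 3.0]])

 # masked_select flattens both tensors down to the 5 valid values.
 loss = nn.L1Loss()(pred.masked_select(mask), target.masked_select(mask))
 print(mask.sum().item(), loss.item())                           # 5 0.2

The padded 9.0 is ignored entirely; without the mask, the same MAE would be inflated by the padding positions.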

2.model/optimizer.py

This file wraps the Adam optimizer in a scheduling class that adjusts the learning rate dynamically during training: a Transformer-style warmup followed by step-wise annealing.

 import torch
 import numpy as np
 
 # 为学习率更新封装的类
 class ScheduledOptim:
     """ A simple wrapper class for learning rate scheduling """
 
     def __init__(self, model, train_config, model_config, current_step):
 
         self._optimizer = torch.optim.Adam(
             model.parameters(),
             betas=train_config["optimizer"]["betas"],  # betas: [0.9, 0.98]
             eps=train_config["optimizer"]["eps"],  # eps: 0.000000001
             weight_decay=train_config["optimizer"]["weight_decay"],  # weight_decay: 0.0
         )
         self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]  # warmup步数: 4000
         self.anneal_steps = train_config["optimizer"]["anneal_steps"]  # 退火步数: [300000, 400000, 500000]
         self.anneal_rate = train_config["optimizer"]["anneal_rate"]  # 退火率: 0.3
         self.current_step = current_step  # 训练的当前步骤
         self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)  # 初始学习率
 
     # 使用设置的学习率方案进行参数更新
     def step_and_update_lr(self):
         self._update_learning_rate()
         self._optimizer.step()
 
     # 清除梯度
     def zero_grad(self):
         # print(self.init_lr)
         self._optimizer.zero_grad()
 
     # 加载保存的优化器参数
     def load_state_dict(self, path):
         self._optimizer.load_state_dict(path)
 
     # 学习率变化规则
     def _get_lr_scale(self):
         lr = np.min(
             [
                 # np.power(x,y) 返回x的y次方
                 np.power(self.current_step, -0.5),
                 np.power(self.n_warmup_steps, -1.5) * self.current_step,
             ]
         )
         for s in self.anneal_steps:
             # 如果当前训练步数大于设置的退火步数,则按退火率进一步衰减学习率
             if self.current_step > s:
                 lr = lr * self.anneal_rate
         return lr
 
     # 该学习方案中每步的学习率
     def _update_learning_rate(self):
         """ Learning rate scheduling per step """
         self.current_step += 1
         # 计算当前步数的学习率
         lr = self.init_lr * self._get_lr_scale()
         # 给所有参数设置学习率
         for param_group in self._optimizer.param_groups:
             param_group["lr"] = lr
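
The schedule implemented above is the Transformer ("Noam") warmup schedule with extra step annealing: lr = encoder_hidden^(-0.5) * min(step^(-0.5), step * warm_up_step^(-1.5)), multiplied by anneal_rate once for every anneal step already passed. Below is a standalone sketch that reproduces the rule and prints the learning rate at a few steps (encoder_hidden is assumed to be 256; the other values match the comments above):

 import numpy as np

 def lr_at(step, d_model=256, warmup=4000,
           anneal_steps=(300000, 400000, 500000), anneal_rate=0.3):
     # Mirrors init_lr * _get_lr_scale() for a given step.
     lr = np.power(d_model, -0.5) * min(np.power(step, -0.5),
                                        np.power(warmup, -1.5) * step)
     for s in anneal_steps:
         if step > s:
             lr *= anneal_rate
     return lr

 for step in [1, 1000, 4000, 20000, 350000, 600000]:
     print(step, "{:.2e}".format(lr_at(step)))
 # The rate rises linearly until warm_up_step, decays as step^-0.5 afterwards,
 # and is scaled by 0.3 each time an anneal step is passed.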
 

3.dataset.py

This file handles data loading and conversion. It loads the preprocessed phoneme sequences, durations, mel-spectrograms, pitch sequences, and energy sequences, and converts them into the padded, batched form the model consumes directly.

 import json
 import math
 import os
 
 import numpy as np
 from torch.utils.data import Dataset
 
 from text import text_to_sequence
 from utils.tools import pad_1D, pad_2D
 
 
 class Dataset(Dataset):
     def __init__(self, filename, preprocess_config, train_config, sort=False, drop_last=False):
         self.dataset_name = preprocess_config["dataset"]  # LibriTTS
         self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]  # "./preprocessed_data/LibriTTS"
         self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]  # ["english_cleaners"]
         self.batch_size = train_config["optimizer"]["batch_size"]  # 16
 
         self.basename, self.speaker, self.text, self.raw_text = self.process_meta(filename)
         with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
             self.speaker_map = json.load(f)
         self.sort = sort
         self.drop_last = drop_last
 
     def __len__(self):
         return len(self.text)
 
     def __getitem__(self, idx):  # 通过下标索引获取数据
         basename = self.basename[idx]  # 文件的basename
         speaker = self.speaker[idx]  # speaker名称,即数据集的名称
         speaker_id = self.speaker_map[speaker]  # speaker对应的数值序号
         raw_text = self.raw_text[idx]  # 原始文本
         phone = np.array(text_to_sequence(self.text[idx], self.cleaners))  # 文本处理后的音素序列
         mel_path = os.path.join(
             self.preprocessed_path,
             "mel",
             "{}-mel-{}.npy".format(speaker, basename),
         )
         mel = np.load(mel_path)  # 加载mel频谱图
         pitch_path = os.path.join(
             self.preprocessed_path,
             "pitch",
             "{}-pitch-{}.npy".format(speaker, basename),
         )
         pitch = np.load(pitch_path)  # 加载pitch序列
         energy_path = os.path.join(
             self.preprocessed_path,
             "energy",
             "{}-energy-{}.npy".format(speaker, basename),
         )
         energy = np.load(energy_path)  # 加载energy序列
         duration_path = os.path.join(
             self.preprocessed_path,
             "duration",
             "{}-duration-{}.npy".format(speaker, basename),
         )
         duration = np.load(duration_path)  # 加载持续时间
 
         sample = {
             "id": basename,
             "speaker": speaker_id,
             "text": phone,
             "raw_text": raw_text,
             "mel": mel,
             "pitch": pitch,
             "energy": energy,
             "duration": duration,
         }
 
         return sample  # 返回数据
 
     # 加载每个音频对应的文本数据
     def process_meta(self, filename):
         with open(
             os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
         ) as f:
             name = []
             speaker = []
             text = []
             raw_text = []
             for line in f.readlines():
                 n, s, t, r = line.strip("\n").split("|")
                 name.append(n)
                 speaker.append(s)
                 text.append(t)
                 raw_text.append(r)
             return name, speaker, text, raw_text
 
     # 对数据进一步转换
     def reprocess(self, data, idxs):
         ids = [data[idx]["id"] for idx in idxs]
         speakers = [data[idx]["speaker"] for idx in idxs]
         texts = [data[idx]["text"] for idx in idxs]
         raw_texts = [data[idx]["raw_text"] for idx in idxs]
         mels = [data[idx]["mel"] for idx in idxs]
         pitches = [data[idx]["pitch"] for idx in idxs]
         energies = [data[idx]["energy"] for idx in idxs]
         durations = [data[idx]["duration"] for idx in idxs]
 
         text_lens = np.array([text.shape[0] for text in texts])  # 文本序列长度列表
         mel_lens = np.array([mel.shape[0] for mel in mels])  # mel图谱序列长度列表
 
         speakers = np.array(speakers)
         # 对以下的序列进行对应维度的pad
         texts = pad_1D(texts)
         mels = pad_2D(mels)
         pitches = pad_1D(pitches)
         energies = pad_1D(energies)
         durations = pad_1D(durations)
 
         return (
             ids,
             raw_texts,
             speakers,
             texts,
             text_lens,
             max(text_lens),
             mels,
             mel_lens,
             max(mel_lens),
             pitches,
             energies,
             durations,
         )
 
     # 定义数据集时使用的数据转换回调函数
     def collate_fn(self, data):
         data_size = len(data)
 
         if self.sort:  # 如果排序
             len_arr = np.array([d["text"].shape[0] for d in data])
             idx_arr = np.argsort(-len_arr)  # 返回文本序列长度从大到小排序的索引序列
         else:
             idx_arr = np.arange(data_size)
         # 当一个batch传入的数据量不是batch_size的整数倍时,tail就是最后不够一个batch_size的数据
         tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :]
         # 前面batch_size的整数倍数据对应的序列列表
         idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
         idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
         # 如果不删除最后剩下的tail部分,并且tail不为空
         if not self.drop_last and len(tail) > 0:
             # 将tail的索引序列添加
             idx_arr += [tail.tolist()]
 
         output = list()
         for idx in idx_arr:
             # 调用reprocess函数进一步对数据转化,主要是进行pad操作
             output.append(self.reprocess(data, idx))
 
         return output
 
 # 用于语音合成时构建推理的数据集类,主要步骤基本一致,因为只需处理文本部分,故少了处理音频文件的部分代码
 class TextDataset(Dataset):
     def __init__(self, filepath, preprocess_config):
         self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]  # ["english_cleaners"]
 
         self.basename, self.speaker, self.text, self.raw_text = self.process_meta(filepath)
         with open(
             os.path.join(
                 preprocess_config["path"]["preprocessed_path"], "speakers.json"
             )
         ) as f:
             self.speaker_map = json.load(f)
 
     def __len__(self):
         return len(self.text)
 
     def __getitem__(self, idx):
         basename = self.basename[idx]
         speaker = self.speaker[idx]
         speaker_id = self.speaker_map[speaker]
         raw_text = self.raw_text[idx]
         phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
 
         return (basename, speaker_id, phone, raw_text)
 
     def process_meta(self, filename):
         with open(filename, "r", encoding="utf-8") as f:
             name = []
             speaker = []
             text = []
             raw_text = []
             for line in f.readlines():
                 n, s, t, r = line.strip("\n").split("|")
                 name.append(n)
                 speaker.append(s)
                 text.append(t)
                 raw_text.append(r)
             return name, speaker, text, raw_text
 
     def collate_fn(self, data):
         ids = [d[0] for d in data]
         speakers = np.array([d[1] for d in data])
         texts = [d[2] for d in data]
         raw_texts = [d[3] for d in data]
         text_lens = np.array([text.shape[0] for text in texts])
 
         texts = pad_1D(texts)
 
         return ids, raw_texts, speakers, texts, text_lens, max(text_lens)
 
 
 if __name__ == "__main__":
     # Test
     import torch
     import yaml
     from torch.utils.data import DataLoader
     from utils.utils import to_device
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     preprocess_config = yaml.load(open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader)
     train_config = yaml.load(open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader)
 
     train_dataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True)
     val_dataset = Dataset("val.txt", preprocess_config, train_config, sort=False, drop_last=False)
     train_loader = DataLoader(
         train_dataset,
         batch_size=train_config["optimizer"]["batch_size"] * 4, # 16 * 4
         shuffle=True,
         collate_fn=train_dataset.collate_fn,
     )
     val_loader = DataLoader(
         val_dataset,
         batch_size=train_config["optimizer"]["batch_size"],
         shuffle=False,
         collate_fn=val_dataset.collate_fn,
     )
 
     n_batch = 0
     for batchs in train_loader:
         for batch in batchs:
             to_device(batch, device)
             n_batch += 1
     print(
         "Training set  with size {} is composed of {} batches.".format(
             len(train_dataset), n_batch
         )
     )
 
     n_batch = 0
     for batchs in val_loader:
         for batch in batchs:
             to_device(batch, device)
             n_batch += 1
     print(
         "Validation set  with size {} is composed of {} batches.".format(
             len(val_dataset), n_batch
         )
     )
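
To make the group batching in collate_fn concrete, the following standalone sketch (fake sequence lengths instead of real samples) mirrors its index bookkeeping: the group delivered by the DataLoader is sorted by text length, cut into chunks of batch_size, and the remainder is kept as a smaller tail batch unless drop_last is set:

 import numpy as np

 def group_indices(text_lens, batch_size=4, sort=True, drop_last=False):
     # Same index bookkeeping as Dataset.collate_fn, applied to a list of lengths.
     text_lens = np.array(text_lens)
     idx_arr = np.argsort(-text_lens) if sort else np.arange(len(text_lens))
     tail = idx_arr[len(idx_arr) - (len(idx_arr) % batch_size):]
     idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % batch_size)]
     batches = idx_arr.reshape((-1, batch_size)).tolist()
     if not drop_last and len(tail) > 0:
         batches += [tail.tolist()]
     return batches

 lens = [5, 12, 7, 3, 9, 11, 2, 8, 6, 10]    # 10 samples, batch_size 4
 print(group_indices(lens))
 # [[1, 5, 9, 4], [7, 2, 8, 0], [3, 6]] — two full batches plus a tail of 2.

Sorting by length means each real batch contains sequences of similar length, so pad_1D/pad_2D add as little padding as possible.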

4.train.py

This file implements the FastSpeech2 training procedure. The overall flow is that of an ordinary training script, with one thing to note: each item yielded by the DataLoader is a large group containing several real batches, so the usual nested for loops sit inside an outer "while True" loop. Training does not stop after a fixed number of epochs; it terminates only when the step counter reaches the configured total number of training steps (total_step).

 import argparse
 import os
 
 import torch
 import yaml
 import torch.nn as nn
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from utils.model import get_model, get_vocoder, get_param_num
 from utils.tools import to_device, log, synth_one_sample
 from model import FastSpeech2Loss
 from dataset import Dataset
 
 from evaluate import evaluate
 
 # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
 
 
 def main(args, configs):
     print("Prepare training ...")
     # 加载预处理、模型和训练的配置文件
     preprocess_config, model_config, train_config = configs
 
     # Get dataset
     # 加载训练数据集
     dataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True)
     batch_size = train_config["optimizer"]["batch_size"] # batch_size = 16
     group_size = 4  # Set this larger than 1 to enable sorting in Dataset
     assert batch_size * group_size < len(dataset)
     loader = DataLoader(
         dataset,
         batch_size=batch_size * group_size,  # 16*4
         shuffle=True,
         collate_fn=dataset.collate_fn,
     )
 
     # Prepare model
     model, optimizer = get_model(args, configs, device, train=True)  # 加载模型和优化器
     model = nn.DataParallel(model)
     num_param = get_param_num(model)  # 计算模型参数量
     Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)  # 定义损失函数
     print("Number of FastSpeech2 Parameters:", num_param)
 
     # Load vocoder
     vocoder = get_vocoder(model_config, device)  # 加载vocoder
 
     # Init logger
     for p in train_config["path"].values():
         os.makedirs(p, exist_ok=True)
     # log_path "./output/log/LibriTTS" /train or /val
     train_log_path = os.path.join(train_config["path"]["log_path"], "train")
     val_log_path = os.path.join(train_config["path"]["log_path"], "val")
     os.makedirs(train_log_path, exist_ok=True)
     os.makedirs(val_log_path, exist_ok=True)
     # 使用tensorboard记录训练过程
     train_logger = SummaryWriter(train_log_path)
     val_logger = SummaryWriter(val_log_path)
 
     # Training
     step = args.restore_step + 1  # 当前步数
     epoch = 1
     grad_acc_step = train_config["optimizer"]["grad_acc_step"]  # 梯度累步数值
     grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]  # 梯度剪裁的值
     total_step = train_config["step"]["total_step"]  # 总的训练步数
     log_step = train_config["step"]["log_step"]  # 100
     save_step = train_config["step"]["save_step"]  # 100000
     synth_step = train_config["step"]["synth_step"]  # 1000
     val_step = train_config["step"]["val_step"]  # 1000
 
     outer_bar = tqdm(total=total_step, desc="Training", position=0)  # 显示所有步数的运行情况
     outer_bar.n = args.restore_step  # 加载之前已经训练完的步数
     outer_bar.update()
 
     while True:
         # 显示当前epoch内的训练步数情况
         inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
         for batchs in loader:
             for batch in batchs:  # 根据前面的设置,一个batchs中是有group_size个batch的
                 batch = to_device(batch, device)
 
                 # Forward
                 output = model(*(batch[2:]))
 
                 # Cal Loss
                 losses = Loss(batch, output)  # 计算损失
                 total_loss = losses[0]  # 总损失
 
                 # Backward
                 total_loss = total_loss / grad_acc_step
                 total_loss.backward()
                 # 到了梯度累计释放的步数
                 if step % grad_acc_step == 0:
                     # Clipping gradients to avoid gradient explosion
                     # 梯度剪裁
                     nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
 
                     # Update weights
                     optimizer.step_and_update_lr()
                     optimizer.zero_grad()
                 # 到了记录的步数,写日志
                 if step % log_step == 0:
                     losses = [l.item() for l in losses]
                     message1 = "Step {}/{}, ".format(step, total_step)
                     message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                         *losses
                     )
 
                     with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                         f.write(message1 + message2 + "\n")
                     # 将日志信息在进度条的后面显示
                     outer_bar.write(message1 + message2)
                     # 调用定义的日志函数在tensorboard中记录信息
                     log(train_logger, step, losses=losses)
 
                 # 到了合成音频的步数
                 if step % synth_step == 0:
                     fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                         batch,
                         output,
                         vocoder,
                         model_config,
                         preprocess_config,
                     )
 
                     log(train_logger, fig=fig,tag="Training/step_{}_{}".format(step, tag),)
                     sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]  # 22050
                     # 记录以target_mel谱图使用vocoder重构的音频
                     log(
                         train_logger,
                         audio=wav_reconstruction,
                         sampling_rate=sampling_rate,
                         tag="Training/step_{}_{}_reconstructed".format(step, tag),
                     )
                     # 记录以生成的prediction_mel谱图使用vocoder重构的音频
                     log(
                         train_logger,
                         audio=wav_prediction,
                         sampling_rate=sampling_rate,
                         tag="Training/step_{}_{}_synthesized".format(step, tag),
                     )
 
                 # 到了验证的步数
                 if step % val_step == 0:
                     # 切换验证模式
                     model.eval()
                     message = evaluate(model, step, configs, val_logger, vocoder)
                     with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                         f.write(message + "\n")
                     outer_bar.write(message)
 
                     model.train()  # 退出时设置回训练模式
 
                 # 到了模型保存的步数,每十万步保存一次模型,
                 # 100000.pth.tar,200000.pth.tar,...
                 # 共9个
                 if step % save_step == 0:
                     torch.save(
                         {
                             "model": model.module.state_dict(),
                             "optimizer": optimizer._optimizer.state_dict(),
                         },
                         os.path.join(
                             train_config["path"]["ckpt_path"],
                             "{}.pth.tar".format(step),
                         ),
                     )
 
                 # 如果到了设置的训练总步数,就停止训练
                 if step == total_step:
                     quit()
                 step += 1
                 # 当前epoch每训练一个step也要在outer_bar中更新
                 outer_bar.update(1)
 
             inner_bar.update(1)
         epoch += 1
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--restore_step", type=int, default=0)
     # -p path to preprocess.yaml
     parser.add_argument(
         "-p",
         "--preprocess_config",
         type=str,
         required=True,
         help="path to preprocess.yaml",
     )
     # -m path to model.yaml
     parser.add_argument(
         "-m", "--model_config", type=str, required=True, help="path to model.yaml"
     )
     # -t path to train.yaml
     parser.add_argument(
         "-t", "--train_config", type=str, required=True, help="path to train.yaml"
     )
     args = parser.parse_args()
 
     # Read Config
     preprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader)
     model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
     train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
     configs = (preprocess_config, model_config, train_config)
     # 传入config文件路径preprocess_config,model_config,train_config
     main(args, configs)
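
One detail of the loop worth isolating is the gradient accumulation: the loss is divided by grad_acc_step and backward() runs on every step, but the optimizer only updates the weights (after clipping) once every grad_acc_step steps, which emulates a batch grad_acc_step times larger. A minimal self-contained sketch of the pattern, using a dummy linear model and random data rather than the real FastSpeech2 objects:

 import torch
 import torch.nn as nn

 model = nn.Linear(10, 1)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 grad_acc_step, grad_clip_thresh = 2, 1.0

 for step in range(1, 9):                       # pretend training steps
     x, y = torch.randn(4, 10), torch.randn(4, 1)
     loss = nn.functional.mse_loss(model(x), y)
     (loss / grad_acc_step).backward()          # scale so accumulated grads average out
     if step % grad_acc_step == 0:
         nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
         optimizer.step()                       # weights update every grad_acc_step steps
         optimizer.zero_grad()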
 

5.utils/tools.py

This file defines the helper functions used throughout data conversion, training, and synthesis: moving batches to the device, TensorBoard logging, mask generation, duration-based expansion, mel/pitch/energy plotting, and padding.

 import os
 import json
 
 import torch
 import torch.nn.functional as F
 import numpy as np
 import matplotlib
 from scipy.io import wavfile
 from matplotlib import pyplot as plt
 
 
 matplotlib.use("Agg")
 
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # 将训练或推理过程时的各类数据传入到对应device
 def to_device(data, device):
     if len(data) == 12:  # 训练时,将dataloader中的数据转入device
         (
             ids,
             raw_texts,
             speakers,
             texts,
             src_lens,
             max_src_len,
             mels,
             mel_lens,
             max_mel_len,
             pitches,
             energies,
             durations,
         ) = data
 
         speakers = torch.from_numpy(speakers).long().to(device)
         texts = torch.from_numpy(texts).long().to(device)
         src_lens = torch.from_numpy(src_lens).to(device)
         mels = torch.from_numpy(mels).float().to(device)
         mel_lens = torch.from_numpy(mel_lens).to(device)
         pitches = torch.from_numpy(pitches).float().to(device)
         energies = torch.from_numpy(energies).to(device)
         durations = torch.from_numpy(durations).long().to(device)
 
         return (
             ids,
             raw_texts,
             speakers,
             texts,
             src_lens,
             max_src_len,
             mels,
             mel_lens,
             max_mel_len,
             pitches,
             energies,
             durations,
         )
 
     if len(data) == 6:   # 推理时,将dataloader中的数据转入device
         (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data
 
         speakers = torch.from_numpy(speakers).long().to(device)
         texts = torch.from_numpy(texts).long().to(device)
         src_lens = torch.from_numpy(src_lens).to(device)
 
         return (ids, raw_texts, speakers, texts, src_lens, max_src_len)
 
 # 定义的tensorboard日志记录函数
 def log(logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""):
     if losses is not None:  # 记录训练过程中所有不同的损失
         logger.add_scalar("Loss/total_loss", losses[0], step)
         logger.add_scalar("Loss/mel_loss", losses[1], step)
         logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
         logger.add_scalar("Loss/pitch_loss", losses[3], step)
         logger.add_scalar("Loss/energy_loss", losses[4], step)
         logger.add_scalar("Loss/duration_loss", losses[5], step)
 
     if fig is not None:   # 记录图片
         logger.add_figure(tag, fig)
 
     if audio is not None:  # 记录音频
         logger.add_audio(tag, audio / max(abs(audio)),sample_rate=sampling_rate,)
 
 # 给整个batch的所有数据生成对应的mask
 def get_mask_from_lengths(lengths, max_len=None):
     batch_size = lengths.shape[0]
     # 如果没有传入最大长度,就以传入batch中序列长度最大的值作为标准
     if max_len is None:
         max_len = torch.max(lengths).item()
     # 先生成一个完整的模板,尺寸是[batch_size, max_len], 其中每一行都是[0, 1, 2, ..., max_len-1]
     ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
     # 此处mask中,序列真实长度对应的位置为False,而超出序列长度的位置为True
     mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
 
     return mask
 
 # 根据持续时间duration调整pitch、energy序列
 def expand(values, durations):
     out = list()
     # zip() return (a[i], b[i])
     for value, d in zip(values, durations):
         # 将序列中对应的value重复d次
         out += [value] * max(0, int(d))
     return np.array(out)
 
 # 训练时合成一个音频样本
 def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
 
     basename = targets[0][0]
     src_len = predictions[8][0].item()
     mel_len = predictions[9][0].item()
     mel_target = targets[6][0, :mel_len].detach().transpose(0, 1)
     mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
     duration = targets[11][0, :src_len].detach().cpu().numpy()
     if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
         pitch = targets[9][0, :src_len].detach().cpu().numpy()
         pitch = expand(pitch, duration)
     else:
         pitch = targets[9][0, :mel_len].detach().cpu().numpy()
     if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
         energy = targets[10][0, :src_len].detach().cpu().numpy()
         energy = expand(energy, duration)
     else:
         energy = targets[10][0, :mel_len].detach().cpu().numpy()
 
     with open(
         os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
     ) as f:
         stats = json.load(f)
         stats = stats["pitch"] + stats["energy"][:2]
     # 绘制mel谱图
     fig = plot_mel(
         [
             (mel_prediction.cpu().numpy(), pitch, energy),
             (mel_target.cpu().numpy(), pitch, energy),
         ],
         stats,
         ["Synthetized Spectrogram", "Ground-Truth Spectrogram"],
     )
     # 加载vocoder
     if vocoder is not None:
         from .model import vocoder_infer
 
         wav_reconstruction = vocoder_infer(
             mel_target.unsqueeze(0),  # 基于gt音频数据抽取的mel谱图重建音频
             vocoder,
             model_config,
             preprocess_config,
         )[0]
         wav_prediction = vocoder_infer(
             mel_prediction.unsqueeze(0),
             vocoder,
             model_config,
             preprocess_config,
         )[0]
     else:
         wav_reconstruction = wav_prediction = None
 
     return fig, wav_reconstruction, wav_prediction, basename
 
 # 批量合成音频
 def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
 
     basenames = targets[0]
     for i in range(len(predictions[0])):
         basename = basenames[i]
         src_len = predictions[8][i].item()
         mel_len = predictions[9][i].item()
         mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
         duration = predictions[5][i, :src_len].detach().cpu().numpy()
         if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
             pitch = predictions[2][i, :src_len].detach().cpu().numpy()
             pitch = expand(pitch, duration)
         else:
             pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
         if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
             energy = predictions[3][i, :src_len].detach().cpu().numpy()
             energy = expand(energy, duration)
         else:
             energy = predictions[3][i, :mel_len].detach().cpu().numpy()
 
         with open(
             os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
         ) as f:
             stats = json.load(f)
             stats = stats["pitch"] + stats["energy"][:2]
 
         fig = plot_mel(
             [
                 (mel_prediction.cpu().numpy(), pitch, energy),
             ],
             stats,
             ["Synthetized Spectrogram"],
         )
         plt.savefig(os.path.join(path, "{}.png".format(basename)))
         plt.close()
 
     from .model import vocoder_infer
 
     mel_predictions = predictions[1].transpose(1, 2)
     lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
     wav_predictions = vocoder_infer(
         mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
     )
 
     sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
     for wav, basename in zip(wav_predictions, basenames):
         wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)
 
 
 def plot_mel(data, stats, titles):
     fig, axes = plt.subplots(len(data), 1, squeeze=False)
     if titles is None:
         titles = [None for i in range(len(data))]
     pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
     pitch_min = pitch_min * pitch_std + pitch_mean
     pitch_max = pitch_max * pitch_std + pitch_mean
 
     def add_axis(fig, old_ax):
         ax = fig.add_axes(old_ax.get_position(), anchor="W")
         ax.set_facecolor("None")
         return ax
 
     for i in range(len(data)):
         mel, pitch, energy = data[i]
         pitch = pitch * pitch_std + pitch_mean
         axes[i][0].imshow(mel, origin="lower")
         axes[i][0].set_aspect(2.5, adjustable="box")
         axes[i][0].set_ylim(0, mel.shape[0])
         axes[i][0].set_title(titles[i], fontsize="medium")
         axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
         axes[i][0].set_anchor("W")
 
         ax1 = add_axis(fig, axes[i][0])
         ax1.plot(pitch, color="tomato")
         ax1.set_xlim(0, mel.shape[1])
         ax1.set_ylim(0, pitch_max)
         ax1.set_ylabel("F0", color="tomato")
         ax1.tick_params(
             labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
         )
 
         ax2 = add_axis(fig, axes[i][0])
         ax2.plot(energy, color="darkviolet")
         ax2.set_xlim(0, mel.shape[1])
         ax2.set_ylim(energy_min, energy_max)
         ax2.set_ylabel("Energy", color="darkviolet")
         ax2.yaxis.set_label_position("right")
         ax2.tick_params(
             labelsize="x-small",
             colors="darkviolet",
             bottom=False,
             labelbottom=False,
             left=False,
             labelleft=False,
             right=True,
             labelright=True,
         )
 
     return fig
 
 # pad一维张量
 def pad_1D(inputs, PAD=0):
     def pad_data(x, length, PAD):
         x_padded = np.pad(
             x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
         )
         return x_padded
 
     max_len = max((len(x) for x in inputs))
     padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
 
     return padded
 
 # pad二维张量
 def pad_2D(inputs, maxlen=None):
     def pad(x, max_len):
         PAD = 0
         if np.shape(x)[0] > max_len:
             raise ValueError("not max_len")
 
         s = np.shape(x)[1]
         x_padded = np.pad(
             x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
         )
         return x_padded[:, :s]
 
     if maxlen:
         output = np.stack([pad(x, maxlen) for x in inputs])
     else:
         max_len = max(np.shape(x)[0] for x in inputs)
         output = np.stack([pad(x, max_len) for x in inputs])
 
     return output
 
 # 对长度对齐后的音素序列进行pad
 def pad(input_ele, mel_max_length=None):
     if mel_max_length:
         max_len = mel_max_length
     else:
         max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
 
     out_list = list()
     for i, batch in enumerate(input_ele):  # 此处的一个batch其实是一个音素序列
         if len(batch.shape) == 1:
             #  batch.size(0)即获取音素序列长度
             one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0)
         elif len(batch.shape) == 2:
             one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0)
         out_list.append(one_batch_padded)
     out_padded = torch.stack(out_list)
     return out_padded
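
A few of these helpers are easiest to see on toy inputs. The sketch below (made-up values) shows that get_mask_from_lengths marks padded positions with True, that expand repeats each phoneme-level value by its duration so it becomes frame-aligned, and that pad_1D-style padding right-pads a list of 1-D arrays to a common length:

 import torch
 import numpy as np

 # get_mask_from_lengths: True marks the padded positions.
 lengths = torch.tensor([2, 4])
 ids = torch.arange(0, 4).unsqueeze(0).expand(2, -1)
 print(ids >= lengths.unsqueeze(1))
 # tensor([[False, False,  True,  True],
 #         [False, False, False, False]])

 # expand: repeat each phoneme-level value by its duration (in frames).
 pitch = np.array([1.0, 2.0, 3.0])
 duration = np.array([2, 1, 3])
 print(np.concatenate([[v] * int(d) for v, d in zip(pitch, duration)]))
 # [1. 1. 2. 3. 3. 3.]  -> frame-aligned, length == duration.sum()

 # pad_1D-style padding: right-pad every sequence to the longest one with zeros.
 seqs = [np.array([1, 2, 3]), np.array([4, 5])]
 max_len = max(len(x) for x in seqs)
 print(np.stack([np.pad(x, (0, max_len - len(x))) for x in seqs]))
 # [[1 2 3]
 #  [4 5 0]]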
 

6.utils/model.py

This file defines the helpers for building or restoring the FastSpeech2 model and its optimizer, counting parameters, loading the vocoder (MelGAN or HiFi-GAN), and converting mel-spectrograms into waveforms with that vocoder.

 import os
 import json
 
 import torch
 import numpy as np
 
 import hifigan
 from model import FastSpeech2, ScheduledOptim
 
 
 def get_model(args, configs, device, train=False):
     (preprocess_config, model_config, train_config) = configs
 
     model = FastSpeech2(preprocess_config, model_config).to(device) # 初始化FastSpeech2模型
     if args.restore_step:  # 如果之前有存储的模型参数,就加载
         ckpt_path = os.path.join(
             train_config["path"]["ckpt_path"],
             "{}.pth.tar".format(args.restore_step),
         )
         ckpt = torch.load(ckpt_path)
         model.load_state_dict(ckpt["model"])
 
     # 训练过程
     if train:
         # 初始化优化器
         scheduled_optim = ScheduledOptim(
             model, train_config, model_config, args.restore_step
         )
         # 加载优化器参数
         if args.restore_step:
             scheduled_optim.load_state_dict(ckpt["optimizer"])
         # 设置训练模式
         model.train()
         # 返回模型和优化器
         return model, scheduled_optim
     # 推理过程
     model.eval()
     # 参数不需要计算梯度
     model.requires_grad_(False)
     # 返回模型
     return model
 
 # 计算模型的参数总量
 def get_param_num(model):
     num_param = sum(param.numel() for param in model.parameters())
     return num_param
 
 # 加载vocoder
 def get_vocoder(config, device):
     name = config["vocoder"]["model"]  # hifigan
     speaker = config["vocoder"]["speaker"]  # universal
 
     if name == "MelGAN":
         if speaker == "LJSpeech":
             vocoder = torch.hub.load(
                 "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
             )
         elif speaker == "universal":
             vocoder = torch.hub.load(
                 "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
             )
         vocoder.mel2wav.eval()
         vocoder.mel2wav.to(device)
     # hifigan
     elif name == "HiFi-GAN":
         with open("hifigan/config.json", "r") as f:
             config = json.load(f)
         config = hifigan.AttrDict(config)
         vocoder = hifigan.Generator(config)
         if speaker == "LJSpeech":
             ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar")
         elif speaker == "universal":
             ckpt = torch.load("hifigan/generator_universal.pth.tar")
         vocoder.load_state_dict(ckpt["generator"])
         vocoder.eval()
         vocoder.remove_weight_norm()
         vocoder.to(device)
 
     return vocoder
 
 # vocoder使用mel谱图生成音频
 def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
     name = model_config["vocoder"]["model"]
     with torch.no_grad():
         if name == "MelGAN":
             wavs = vocoder.inverse(mels / np.log(10))
         elif name == "HiFi-GAN":
             wavs = vocoder(mels).squeeze(1)
 
     wavs = (
         wavs.cpu().numpy()
         * preprocess_config["preprocessing"]["audio"]["max_wav_value"]
     ).astype("int16")
     wavs = [wav for wav in wavs]
 
     for i in range(len(mels)):
         if lengths is not None:
             wavs[i] = wavs[i][: lengths[i]]
 
     return wavs
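
A hedged usage sketch of these two helpers. It assumes the repository layout is in place (the hifigan/ config and the universal generator checkpoint), and it feeds a random 80-bin mel-spectrogram instead of a real model prediction, so the audio is noise; the point is only the call sequence and the shapes:

 import torch
 from utils.model import get_vocoder, vocoder_infer

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Minimal configs containing only the keys these helpers read (values assumed).
 model_config = {"vocoder": {"model": "HiFi-GAN", "speaker": "universal"}}
 preprocess_config = {"preprocessing": {"audio": {"max_wav_value": 32768.0},
                                        "stft": {"hop_length": 256}}}

 vocoder = get_vocoder(model_config, device)
 mel = torch.randn(1, 80, 200).to(device)        # [batch, n_mels, frames]
 wavs = vocoder_infer(mel, vocoder, model_config, preprocess_config)
 print(len(wavs), wavs[0].shape)                 # 1 waveform, roughly frames * hop_length samples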
 

7.synthesize.py

This file implements the complete synthesis pipeline, i.e., how the trained model is used: raw text is converted into a phoneme sequence, the model predicts a mel-spectrogram, and the vocoder renders the waveform.

 import re
 import argparse
 from string import punctuation
 
 import torch
 import yaml
 import numpy as np
 from torch.utils.data import DataLoader
 from g2p_en import G2p
 from pypinyin import pinyin, Style
 
 from utils.model import get_model, get_vocoder
 from utils.tools import to_device, synth_samples
 from dataset import TextDataset
 from text import text_to_sequence
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # 加载词典
 def read_lexicon(lex_path):
     lexicon = {}
     with open(lex_path) as f:
         for line in f:
             temp = re.split(r"\s+", line.strip("\n"))
             word = temp[0]
             phones = temp[1:]
             if word.lower() not in lexicon:
                 lexicon[word.lower()] = phones
     return lexicon
 
 # 处理英文,将其转换为音素序列
 def preprocess_english(text, preprocess_config):
     text = text.rstrip(punctuation)
     lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])  # "lexicon/librispeech-lexicon.txt"
 
     g2p = G2p()
     phones = []
     words = re.split(r"([,;.\-\?\!\s+])", text)
     for w in words:
         if w.lower() in lexicon:
             phones += lexicon[w.lower()]
         else:
             phones += list(filter(lambda p: p != " ", g2p(w)))
     phones = "{" + "}{".join(phones) + "}"
     phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
     phones = phones.replace("}{", " ")
 
     print("Raw Text Sequence: {}".format(text))
     print("Phoneme Sequence: {}".format(phones))
     sequence = np.array(
         text_to_sequence(
             phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
         )
     )
 
     return np.array(sequence)
 
 # 将普通话转换为音素序列
 def preprocess_mandarin(text, preprocess_config):
     lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
 
     phones = []
     pinyins = [
         p[0]
         for p in pinyin(
             text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
         )
     ]
     for p in pinyins:
         if p in lexicon:
             phones += lexicon[p]
         else:
             phones.append("sp")
 
     phones = "{" + " ".join(phones) + "}"
     print("Raw Text Sequence: {}".format(text))
     print("Phoneme Sequence: {}".format(phones))
     sequence = np.array(
         text_to_sequence(
             phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
         )
     )
 
     return np.array(sequence)
 
 # 基于由文本转化而来的音素序列生成音频
 def synthesize(model, step, configs, vocoder, batchs, control_values):
     preprocess_config, model_config, train_config = configs
     pitch_control, energy_control, duration_control = control_values
 
     for batch in batchs:
         batch = to_device(batch, device)
         with torch.no_grad():
             # Forward
             output = model(
                 *(batch[2:]),
                 p_control=pitch_control,
                 e_control=energy_control,
                 d_control=duration_control
             )
             synth_samples(
                 batch,
                 output,
                 vocoder,
                 model_config,
                 preprocess_config,
                 train_config["path"]["result_path"],
             )
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--restore_step", type=int, required=True)
     # 模式,批量合成或者单句合成
     parser.add_argument(
         "--mode",
         type=str,
         choices=["batch", "single"],
         required=True,
         help="Synthesize a whole dataset or a single sentence",
     )
     # 仅用于批量合成,读取需要合成的文本文件
     parser.add_argument(
         "--source",
         type=str,
         default=None,
         help="path to a source file with format like train.txt and val.txt, for batch mode only",
     )
     # 单句合成时,需要合成的文本
     parser.add_argument(
         "--text",
         type=str,
         default=None,
         help="raw text to synthesize, for single-sentence mode only",
     )
     # 单句合成选择一个说话人进行合成
     parser.add_argument(
         "--speaker_id",
         type=int,
         default=0,
         help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
     )
     # the path to preprocess.yaml
     parser.add_argument(
         "-p",
         "--preprocess_config",
         type=str,
         required=True,
         help="path to preprocess.yaml",
     )
     # the path to model.yaml
     parser.add_argument(
         "-m", "--model_config", type=str, required=True, help="path to model.yaml"
     )
     # the path to train.yaml
     parser.add_argument(
         "-t", "--train_config", type=str, required=True, help="path to train.yaml"
     )
     # 控制系数
     parser.add_argument(
         "--pitch_control",
         type=float,
         default=1.0,
         help="control the pitch of the whole utterance, larger value for higher pitch",
     )
     parser.add_argument(
         "--energy_control",
         type=float,
         default=1.0,
         help="control the energy of the whole utterance, larger value for larger volume",
     )
     parser.add_argument(
         "--duration_control",
         type=float,
         default=1.0,
         help="control the speed of the whole utterance, larger value for slower speaking rate",
     )
     args = parser.parse_args()
 
     # Check source texts
     if args.mode == "batch":
         assert args.source is not None and args.text is None
     if args.mode == "single":
         assert args.source is None and args.text is not None
 
     # Read Config
     preprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader)
     model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
     train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
     configs = (preprocess_config, model_config, train_config)
 
     # Get model
     model = get_model(args, configs, device, train=False)
 
     # Load vocoder
     vocoder = get_vocoder(model_config, device)
 
     # Preprocess texts
     if args.mode == "batch":
         # Get dataset
         dataset = TextDataset(args.source, preprocess_config)
         batchs = DataLoader(
             dataset,
             batch_size=8,
             collate_fn=dataset.collate_fn,
         )
     if args.mode == "single":
         ids = raw_texts = [args.text[:100]]
         speakers = np.array([args.speaker_id])
         if preprocess_config["preprocessing"]["text"]["language"] == "en":
             texts = np.array([preprocess_english(args.text, preprocess_config)])
         elif preprocess_config["preprocessing"]["text"]["language"] == "zh":
             texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
         text_lens = np.array([len(texts[0])])
         batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]
 
     control_values = args.pitch_control, args.energy_control, args.duration_control
 
     synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)
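
For reference, the script is driven entirely from the command line. A single-sentence run looks roughly like `python3 synthesize.py --text "YOUR_SENTENCE" --mode single --restore_step 900000 -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml` (the step number and config paths are illustrative); batch mode replaces `--text` with `--source` pointing to a file in the same format as train.txt/val.txt. The optional `--pitch_control`, `--energy_control`, and `--duration_control` factors scale the predicted pitch, energy, and durations, so values above 1.0 give higher pitch, louder output, and slower speech respectively.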
 

8.evaluate.py

This file defines the evaluation function: it runs the model over the validation set, computes a per-utterance average of the six losses (each batch's loss is weighted by its number of utterances before dividing by the dataset size), and optionally logs a synthesized validation sample to TensorBoard.

 import argparse
 import os
 
 import torch
 import yaml
 import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from utils.model import get_model, get_vocoder
 from utils.tools import to_device, log, synth_one_sample
 from model import FastSpeech2Loss
 from dataset import Dataset
 
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 def evaluate(model, step, configs, logger=None, vocoder=None):
     preprocess_config, model_config, train_config = configs
 
     # Get dataset
     dataset = Dataset(
         "val.txt", preprocess_config, train_config, sort=False, drop_last=False
     )
     batch_size = train_config["optimizer"]["batch_size"]
     loader = DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=False,
         collate_fn=dataset.collate_fn,
     )
 
     # Get loss function
     Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
 
     # Evaluation
     loss_sums = [0 for _ in range(6)]
     for batchs in loader:
         for batch in batchs:
             batch = to_device(batch, device)
             with torch.no_grad():
                 # Forward
                 output = model(*(batch[2:]))
 
                 # Cal Loss
                 losses = Loss(batch, output)
 
                 for i in range(len(losses)):
                     loss_sums[i] += losses[i].item() * len(batch[0])
 
     loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]
 
     message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
         *([step] + [l for l in loss_means])
     )
 
     if logger is not None:
         fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
             batch,
             output,
             vocoder,
             model_config,
             preprocess_config,
         )
 
         log(logger, step, losses=loss_means)
         log(
             logger,
             fig=fig,
             tag="Validation/step_{}_{}".format(step, tag),
         )
         sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
         log(
             logger,
             audio=wav_reconstruction,
             sampling_rate=sampling_rate,
             tag="Validation/step_{}_{}_reconstructed".format(step, tag),
         )
         log(
             logger,
             audio=wav_prediction,
             sampling_rate=sampling_rate,
             tag="Validation/step_{}_{}_synthesized".format(step, tag),
         )
 
     return message
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--restore_step", type=int, default=30000)
     parser.add_argument(
         "-p",
         "--preprocess_config",
         type=str,
         required=True,
         help="path to preprocess.yaml",
     )
     parser.add_argument(
         "-m", "--model_config", type=str, required=True, help="path to model.yaml"
     )
     parser.add_argument(
         "-t", "--train_config", type=str, required=True, help="path to train.yaml"
     )
     args = parser.parse_args()
 
     # Read Config
     preprocess_config = yaml.load(
         open(args.preprocess_config, "r"), Loader=yaml.FullLoader
     )
     model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
     train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
     configs = (preprocess_config, model_config, train_config)
 
     # Get model
     model = get_model(args, configs, device, train=False).to(device)
 
     message = evaluate(model, args.restore_step, configs)
     print(message)
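
evaluate.py can also be run standalone given a checkpoint and the three config files, e.g. `python3 evaluate.py --restore_step 900000 -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml` (step and paths illustrative); it simply prints the validation loss message returned by evaluate().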
