diff --git a/recipes/Voicebank/enhance/MetricGAN-U/train.py b/recipes/Voicebank/enhance/MetricGAN-U/train.py index a0c68acd4a265c5f61daadc505a9a6b2ea7c2fd8..0bbd7f83700b6ae7b28155e6740113c18944865f 100644 --- a/recipes/Voicebank/enhance/MetricGAN-U/train.py +++ b/recipes/Voicebank/enhance/MetricGAN-U/train.py @@ -109,6 +109,8 @@ def dnsmos_eval(predict, target): pred_wav = pred_wav / max(abs(pred_wav)) pred_wav = pred_wav[np.newaxis, :] + print("predwav_type", type(pred_wav), "shape", pred_wav.shape) + scores = dnsmos_func.run(pred_wav, sr=16000) score = scores["p808_mos"] score = float(sigmoid(score)) # normalize the score to 0~1 @@ -145,6 +147,8 @@ def dnsmos_eval_valid(predict, target): pred_wav = pred_wav.numpy() pred_wav = pred_wav / max(abs(pred_wav)) + print("predwav_type", type(pred_wav), "shape", pred_wav.shape) + scores = dnsmos_func.run(pred_wav, sr=16000) score = scores["p808_mos"] return score