Skip to content
Snippets Groups Projects
Commit b3cf3c0b authored by Joseph Omar
Browse files

fixes

parent 43ee2337
No related branches found
No related tags found
No related merge requests found
...@@ -14,7 +14,7 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>] ...@@ -14,7 +14,7 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
data_folder: /scratch/jo15g23/vbd/noisy-vctk-16k # e.g, /data/member1/user_jasonfu/noisy-vctk-16k data_folder: /scratch/jo15g23/vbd/noisy-vctk-16k # e.g, /data/member1/user_jasonfu/noisy-vctk-16k
MetricGAN_folder: !ref <output_folder>/enhanced_wavs MetricGAN_folder: !ref <output_folder>/enhanced_wavs
output_folder: !ref /scratch/jo15g23/metricgan/results/<seed> output_folder: !ref /scratch/jo15g23/metricgan_new/<seed>
save_folder: !ref <output_folder>/save save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt train_log: !ref <output_folder>/train_log.txt
enhanced_folder: !ref <output_folder>/enhanced_wavs enhanced_folder: !ref <output_folder>/enhanced_wavs
......
...@@ -32,8 +32,7 @@ from speechbrain.nnet.loss.stoi_loss import stoi_loss ...@@ -32,8 +32,7 @@ from speechbrain.nnet.loss.stoi_loss import stoi_loss
from speechbrain.utils.distributed import run_on_main from speechbrain.utils.distributed import run_on_main
from speechbrain.dataio.sampler import ReproducibleWeightedRandomSampler from speechbrain.dataio.sampler import ReproducibleWeightedRandomSampler
from speechmos import dnsmos as dnsmos_func
from seframework.utils.dnsmos import DNSMOS
### For DNSMOS ### For DNSMOS
# URL for the web service # URL for the web service
...@@ -49,8 +48,6 @@ headers = {"Content-Type": "application/json"} ...@@ -49,8 +48,6 @@ headers = {"Content-Type": "application/json"}
# If authentication is enabled, set the authorization header # If authentication is enabled, set the authorization header
headers["Authorization"] = f"Basic {AUTH_KEY }" headers["Authorization"] = f"Basic {AUTH_KEY }"
dnsmos= DNSMOS(device_id=0, sample_rate=16000, cache_session=True, p835=False)
def sigmoid(x): def sigmoid(x):
s = 1 / (1 + np.exp(-x)) s = 1 / (1 + np.exp(-x))
return s return s
...@@ -112,8 +109,8 @@ def dnsmos_eval(predict, target): ...@@ -112,8 +109,8 @@ def dnsmos_eval(predict, target):
pred_wav = pred_wav / max(abs(pred_wav)) pred_wav = pred_wav / max(abs(pred_wav))
pred_wav = pred_wav[np.newaxis, :] pred_wav = pred_wav[np.newaxis, :]
scores = dnsmos(pred_wav) scores = dnsmos_func.run(pred_wav, sr=16000)
score = scores["mos"] score = scores["p808_mos"]
score = float(sigmoid(score)) # normalize the score to 0~1 score = float(sigmoid(score)) # normalize the score to 0~1
return score return score
...@@ -148,8 +145,8 @@ def dnsmos_eval_valid(predict, target): ...@@ -148,8 +145,8 @@ def dnsmos_eval_valid(predict, target):
pred_wav = pred_wav.numpy() pred_wav = pred_wav.numpy()
pred_wav = pred_wav / max(abs(pred_wav)) pred_wav = pred_wav / max(abs(pred_wav))
scores = dnsmos(pred_wav) scores = dnsmos_func.run(pred_wav, sr=16000)
score = scores["mos"] score = scores["p808_mos"]
return score return score
# data = {"data": pred_wav.tolist()} # data = {"data": pred_wav.tolist()}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment