diff --git a/recipes/DNS/README.md b/recipes/DNS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..82eaf1a5ceefb5a0cacbb9fcb70be76c3f8f8917 --- /dev/null +++ b/recipes/DNS/README.md @@ -0,0 +1,143 @@ +# **Speech Enhancement for Microsoft Deep Noise Suppression (DNS) Challenge – ICASSP 2022** +This repository contains training recipes for a speech enhancement system designed for the 4th Deep Noise Suppression Challenge, organized by Microsoft at Interspeech 2022. <br> +The Deep Noise Suppression Challenge features two distinct tracks: +1. **Real Time Non-Personalized DNS** +2. Real Time Personalized DNS (PDNS) for Fullband Audio + +We focus on implementing solutions only for the first track, which involves real-time non-personalized DNS. + +- **Model and Data** : For this challenge, we employ the [Sepformer model](https://arxiv.org/abs/2010.13154v2) to train our speech enhancement system. Our training utilizes 500 hours of fullband audio. + +- **Evaluation Strategy** : We follow the official evaluation strategy outlined by the ITU-T P.835 subjective test framework. It measures speech quality, background noise quality, and overall audio quality. This is done using [DNSMOS P.835](https://arxiv.org/pdf/2110.01763.pdf), a machine learning-based model capable of predicting SIG (Speech Quality), BAK (Background Noise Quality), and OVRL (Overall Audio Quality). + +**Related links** +- [Official Website](https://www.microsoft.com/en-us/research/academic-program/deep-noise-suppression-challenge-icassp-2022/) +- [DNS-4 ICASSP 2022 github repository](https://github.com/microsoft/DNS-Challenge/tree/5582dcf5ba43155621de72a035eb54a7d233af14) + +## **DNS-4 dataset** +DNS-4 dataset once decompressed, the directory structure and sizes of datasets are: +``` +datasets_fullband 892G ++-- dev_testset 1.7G ++-- impulse_responses 5.9G ++-- noise_fullband 58G +\-- clean_fullband 827G + +-- emotional_speech 2.4G + +-- french_speech 62G + +-- german_speech 319G + +-- italian_speech 42G + +-- read_speech 299G + +-- russian_speech 12G + +-- spanish_speech 65G + +-- vctk_wav48_silence_trimmed 27G + \-- VocalSet_48kHz_mono 974M +``` + +### **Required disk space** +The `dns_download.py` download script downloads the Real-time DNS track data and de-compresses it. The compressed data takes around 550 GB of disk space and when de-compressed you would need 1 TB to store audio files. We bundle this decompressed audio into larger archives called as shards. +However this is not the end, the downloaded clean-audio files, RIRs, and noisy-audio files are further used to synthesize clean-noisy audio pairs for training. Once again, we bundle the synthesized data into shards for efficient and faster accessibility. This means further space will be needed to store the synthesized clean-noisy-noise shards. + +**NOTE** +- This dataset download process can be extremely time-consuming. With a total of 126 splits (train, noise and dev data), the script downloads each split in a serial order. The script also allows concurrent data download (by enabling `--parallel_download` param) by using multiple threads (equal to number of your CPU cores). This is helpful especially when you have access to a large cluster. (Alternatively, you can download all 126 splits and decompress them at once by using array job submission.) + +## **Installing Extra Dependencies** +Before proceeding, ensure you have installed the necessary additional dependencies. 
To do this, simply run the following command in your terminal:
+
+```
+pip install -r extra_requirements.txt
+```
+
+## **Getting started**
+- STEP 1: Download the DNS dataset.
+- STEP 2: Synthesize noisy data.
+- STEP 3: Begin training.
+
+## Step 1: **Download the Real-time DNS track dataset and create the WebDataset shards**
+The DNS dataset can be downloaded by running the script below:
+```
+python dns_download.py --compressed_path DNS-compressed --decompressed_path DNS-dataset
+```
+To enable parallel downloading:
+```
+python dns_download.py --compressed_path DNS-compressed --decompressed_path DNS-dataset --parallel_download
+```
+The compressed archives are downloaded to `DNS-compressed`, and the decompressed audio files end up in `DNS-dataset`.
+
+Next, create the WebDataset shards:
+```
+## webdataset shards for clean_fullband (choose one language at a time, e.g. read, german, ...)
+python create_wds_shards.py DNS-dataset/datasets_fullband/clean_fullband/<read_speech/german_speech/french_speech/...>/ DNS-shards/clean_fullband/
+
+## webdataset shards for noise_fullband
+python create_wds_shards.py DNS-dataset/datasets_fullband/noise_fullband/ DNS-shards/noise_fullband
+
+## webdataset shards for baseline dev-set
+python create_wds_shards.py DNS-dataset/datasets_fullband/dev_testset/noisy_testclips/ DNS-shards/devsets_fullband
+```
+## Step 2: **Synthesize noisy data and create the WebDataset shards**
+To synthesize clean-noisy pairs for speech enhancement training (noise and RIRs are added to the clean fullband speech):
+```
+cd noisyspeech_synthesizer
+
+## synthesize read speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name read_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize German speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name german_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Italian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name italian_speech --synthesized_data_dir synthesized_data_shards
+
+## proceed similarly for spanish, russian and french speech.
+```
+*For more details on synthesizing noisy files from clean speech and noise audio, see the `noisyspeech_synthesizer` folder.*
+
+## Step 3: **Begin training**
+To start training:
+```
+cd enhancement
+python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_shards_data>
+```
+*For more details on the main training script and how to perform evaluation, see the `enhancement` folder.*
+
+# **About SpeechBrain**
+- Website: https://speechbrain.github.io/
+- Code: https://github.com/speechbrain/speechbrain/
+- HuggingFace: https://huggingface.co/speechbrain/
+
+
+# **Citing SpeechBrain**
+Please cite SpeechBrain if you use it for your research or business.
+ +```bibtex +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` + + +**Citing SepFormer** +```bibtex +@inproceedings{subakan2021attention, + title={Attention is All You Need in Speech Separation}, + author={Cem Subakan and Mirco Ravanelli and Samuele Cornell and Mirko Bronzi and Jianyuan Zhong}, + year={2021}, + booktitle={ICASSP 2021} +} +``` + +**Citing DNS-4 dataset (ICASSP 2022)** +```bibtex +@inproceedings{dubey2022icassp, + title={ICASSP 2022 Deep Noise Suppression Challenge}, + author={Dubey, Harishchandra and Gopal, Vishak and Cutler, Ross and Matusevych, Sergiy and Braun, Sebastian and Eskimez, Emre Sefik and Thakker, Manthan and Yoshioka, Takuya and Gamper, Hannes and Aichner, Robert}, + booktitle={ICASSP}, + year={2022} +} +``` \ No newline at end of file diff --git a/recipes/DNS/create_wds_shards.py b/recipes/DNS/create_wds_shards.py new file mode 100644 index 0000000000000000000000000000000000000000..fc52ab18da43992d46149b568f3859d7992e3722 --- /dev/null +++ b/recipes/DNS/create_wds_shards.py @@ -0,0 +1,186 @@ +################################################################################ +# +# Converts the uncompressed DNS folder +# {french,german,...}_speech/../<*.wav> +# structure of DNS into a WebDataset format +# +# Author(s): Tanel Alumäe, Nik Vaessen, Sangeet Sagar (2023) +################################################################################ + +import os +import json +from tqdm import tqdm +import pathlib +import argparse +import random +from collections import defaultdict + +import librosa +import torch +import torchaudio +import webdataset as wds + +################################################################################ +# methods for writing the shards + +ID_SEPARATOR = "&" + + +def load_audio(audio_file_path: pathlib.Path) -> torch.Tensor: + t, sr = torchaudio.load(audio_file_path) + + return t + + +def write_shards( + dns_folder_path: pathlib.Path, + shards_path: pathlib.Path, + seed: int, + samples_per_shard: int, + min_dur: float, +): + """ + Parameters + ---------- + dns_folder_path: folder where extracted DNS data is located + shards_path: folder to write shards of data to + seed: random seed used to initially shuffle data into shards + samples_per_shard: number of data samples to store in each shards. 
+ """ + # make sure output folder exist + shards_path.mkdir(parents=True, exist_ok=True) + + # find all audio files + audio_files = sorted([f for f in dns_folder_path.rglob("*.wav")]) + + # create tuples (unique_sample_id, language_id, path_to_audio_file, duration) + data_tuples = [] + + # track statistics on data + all_language_ids = set() + sample_keys_per_language = defaultdict(list) + + if "clean" in dns_folder_path.as_posix(): + delim = "clean_fullband/" + elif "noise" in dns_folder_path.as_posix(): + delim = "noise_fullband/" + lang = "noise" + elif "dev_testset" in dns_folder_path.as_posix(): + delim = "dev_testset/" + lang = "baseline_noisytestset" + else: + delim = os.path.basename(dns_folder_path.as_posix()) + lang = delim + + for f in tqdm(audio_files): + # path should be + # {french,german,...}_speech/../<*.wav> + sub_path = f.as_posix().split(delim)[1] + + loc = f.as_posix() + key = os.path.splitext(os.path.basename(sub_path))[0] + if "clean_fullband" in dns_folder_path.as_posix(): + lang = key.split("_speech")[0] + + dur = librosa.get_duration(path=loc) + + # Period is not allowed in a WebDataset key name + key = key.replace(".", "_") + if dur > min_dur: + # store statistics + all_language_ids.add(lang) + sample_keys_per_language[lang].append(key) + t = (key, lang, loc, dur) + data_tuples.append(t) + + all_language_ids = sorted(all_language_ids) + + # write a meta.json file which contains statistics on the data + # which will be written to shards + meta_dict = { + "language_ids": list(all_language_ids), + "sample_keys_per_language": sample_keys_per_language, + "num_data_samples": len(data_tuples), + } + + with (shards_path / "meta.json").open("w") as f: + json.dump(meta_dict, f, indent=4) + + # shuffle the tuples so that each shard has a large variety in languages + random.seed(seed) + random.shuffle(data_tuples) + + # write shards + all_keys = set() + shards_path.mkdir(exist_ok=True, parents=True) + pattern = str(shards_path / "shard") + "-%06d.tar" + + with wds.ShardWriter(pattern, maxcount=samples_per_shard) as sink: + for key, language_id, f, duration in data_tuples: + # load the audio tensor + tensor = load_audio(f) + + # verify key is unique + assert key not in all_keys + all_keys.add(key) + + # create sample to write + sample = { + "__key__": key, + "audio.pth": tensor, + "language_id": language_id, + } + + # write sample to sink + sink.write(sample) + + +################################################################################ +# define CLI + +parser = argparse.ArgumentParser( + description="Convert DNS-4 to WebDataset shards" +) + +parser.add_argument( + "dns_decompressed_path", + type=pathlib.Path, + help="directory containing the (decompressed) DNS dataset", +) +parser.add_argument( + "shards_path", type=pathlib.Path, help="directory to write shards to" +) +parser.add_argument( + "--seed", + type=int, + default=12345, + help="random seed used for shuffling data before writing to shard", +) +parser.add_argument( + "--samples_per_shard", + type=int, + default=5000, + help="the maximum amount of samples placed in each shard. 
The last shard " + "will most likely contain fewer samples.", +) +parser.add_argument( + "--min-duration", + type=float, + default=3.0, + help="Minimum duration of the audio", +) + + +################################################################################ +# execute script + +if __name__ == "__main__": + args = parser.parse_args() + + write_shards( + args.dns_decompressed_path, + args.shards_path, + args.seed, + args.samples_per_shard, + args.min_duration, + ) diff --git a/recipes/DNS/dns_download.py b/recipes/DNS/dns_download.py new file mode 100644 index 0000000000000000000000000000000000000000..84381f243029326cd63d15292c61c687ae51215c --- /dev/null +++ b/recipes/DNS/dns_download.py @@ -0,0 +1,600 @@ +#!/usr/bin/env/python3 +""" +Recipe for downloading DNS-4 dataset- training, +baseline DEV noisyset, blind testset +Source: +https://github.com/microsoft/DNS-Challenge +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-4.sh + +Disk-space (compressed): 550 GB +Disk-space (decompressed): 1 TB + +NOTE: + 1. Some of the azure links provided by Microsoft are not perfect and data + download may stop mid-way through the download process. Hence we validate + download size of each of the file. + 2. Instead of using the impulse response files provided in the challenge, + we opt to download them from OPENSLR. OPENSLR offers both real and synthetic + RIRs, while the challenge offers only real RIRs. + +Authors + * Sangeet Sagar 2022 +""" + +import os +import ssl +import shutil +import zipfile +import tarfile +import certifi +import argparse +import fileinput +import requests +import urllib.request +from tqdm.auto import tqdm +from concurrent.futures import ThreadPoolExecutor + +BLOB_NAMES = [ + "clean_fullband/datasets_fullband.clean_fullband.VocalSet_48kHz_mono_000_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.emotional_speech_000_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_000_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_001_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_002_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_003_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_004_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_005_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_006_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_007_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.french_speech_008_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_000_0.00_3.47.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_001_3.47_3.64.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_002_3.64_3.74.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_003_3.74_3.81.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_004_3.81_3.86.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_005_3.86_3.91.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_006_3.91_3.96.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_007_3.96_4.00.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_008_4.00_4.04.tar.bz2", + 
"clean_fullband/datasets_fullband.clean_fullband.german_speech_009_4.04_4.08.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_010_4.08_4.12.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_011_4.12_4.16.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_012_4.16_4.21.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_013_4.21_4.26.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_014_4.26_4.33.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_015_4.33_4.43.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_016_4.43_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_017_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_018_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_019_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_020_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_021_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_022_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_023_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_024_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_025_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_026_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_027_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_028_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_029_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_030_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_031_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_032_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_033_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_034_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_035_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_036_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_037_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_038_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_039_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_040_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_041_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.german_speech_042_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_000_0.00_3.98.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_001_3.98_4.21.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_002_4.21_4.40.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_003_4.40_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_004_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.italian_speech_005_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_000_0.00_3.75.tar.bz2", + 
"clean_fullband/datasets_fullband.clean_fullband.read_speech_001_3.75_3.88.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_002_3.88_3.96.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_003_3.96_4.02.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_004_4.02_4.06.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_005_4.06_4.10.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_006_4.10_4.13.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_007_4.13_4.16.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_008_4.16_4.19.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_009_4.19_4.21.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_010_4.21_4.24.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_011_4.24_4.26.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_012_4.26_4.29.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_013_4.29_4.31.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_014_4.31_4.33.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_015_4.33_4.35.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_016_4.35_4.38.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_017_4.38_4.40.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_018_4.40_4.42.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_019_4.42_4.45.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_020_4.45_4.48.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_021_4.48_4.52.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_022_4.52_4.57.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_023_4.57_4.67.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_024_4.67_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_025_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_026_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_027_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_028_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_029_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_030_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_031_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_032_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_033_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_034_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_035_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_036_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_037_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_038_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.read_speech_039_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.russian_speech_000_0.00_4.31.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.russian_speech_001_4.31_NA.tar.bz2", + 
"clean_fullband/datasets_fullband.clean_fullband.spanish_speech_000_0.00_4.09.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_001_4.09_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_002_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_003_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_004_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_005_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_006_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_007_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.spanish_speech_008_NA_NA.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_000.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_001.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_002.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_003.tar.bz2", + "clean_fullband/datasets_fullband.clean_fullband.vctk_wav48_silence_trimmed_004.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_000.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_001.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_002.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_003.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_004.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_005.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.audioset_006.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.freesound_000.tar.bz2", + "noise_fullband/datasets_fullband.noise_fullband.freesound_001.tar.bz2", + "datasets_fullband.dev_testset_000.tar.bz2", +] + +AZURE_URL = ( + "https://dns4public.blob.core.windows.net/dns4archive/datasets_fullband" +) + +# Impulse reponse and Blind testset +OTHER_URLS = { + "impulse_responses": [ + "https://www.openslr.org/resources/26/sim_rir_16k.zip", + "https://www.openslr.org/resources/28/rirs_noises.zip", + ], + "blind_testset": [ + "https://dns4public.blob.core.windows.net/dns4archive/blind_testset_bothtracks.zip" + ], +} + +RIR_table_simple_URL = "https://raw.githubusercontent.com/microsoft/DNS-Challenge/0443a12f5e6e7bec310f453cf0d9637ca28e0eea/datasets/acoustic_params/RIR_table_simple.csv" + +SPLIT_LIST = [ + "dev_testset", + "impulse_responses", + "noise_fullband", + "emotional_speech", + "french_speech", + "german_speech", + "italian_speech", + "read_speech", + "russian_speech", + "spanish_speech", + "vctk_wav48_silence_trimmed", + "VocalSet_48kHz_mono", +] + + +def prepare_download(): + """ + Downloads and prepares various data files and resources. It + downloads real-time DNS track data files (train set and dev + noisy set). 
+ """ + # Real-time DNS track (train set + dev noisy set) + for file_url in BLOB_NAMES: + for split in SPLIT_LIST: + if split in file_url: + split_name = split + + split_path = os.path.join(COMPRESSED_PATH, split_name) + if not os.path.exists(split_path): + os.makedirs(split_path) + if not os.path.exists(DECOMPRESSED_PATH): + os.makedirs(DECOMPRESSED_PATH) + + filename = file_url.split("/")[-1] + download_path = os.path.join(split_path, filename) + download_url = AZURE_URL + "/" + file_url + + if not validate_file(download_url, download_path): + if os.path.exists(download_path): + resume_byte_pos = os.path.getsize(download_path) + else: + resume_byte_pos = None + + download_file( + download_url, + download_path, + split_name, + filename, + resume_byte_pos=resume_byte_pos, + ) + else: + print(", \tDownload complete. Skipping") + decompress_file(download_path, DECOMPRESSED_PATH, split_name) + + # Download RIR (impulse response) & BLIND testset + rir_blind_test_download() + + +def rir_blind_test_download(): + """ + Download the RIRs (room impulse responses), and the blind + test set. + """ + # RIR (impulse response) & BLIND testset + for split_name, download_urls in OTHER_URLS.items(): + for file_url in download_urls: + split_path = os.path.join(COMPRESSED_PATH, split_name) + if not os.path.exists(split_path): + os.makedirs(split_path) + + filename = file_url.split("/")[-1] + download_path = os.path.join(split_path, filename) + + if not validate_file(file_url, download_path): + if os.path.exists(download_path): + resume_byte_pos = os.path.getsize(download_path) + else: + resume_byte_pos = None + + download_file( + file_url, + download_path, + split_name, + filename, + resume_byte_pos=resume_byte_pos, + ) + else: + print(", \tDownload complete. Skipping") + decompress_file( + download_path, + os.path.join(DECOMPRESSED_PATH, split_name), + split_name, + ) + + # Download RIRs simple table + file_path = os.path.join( + DECOMPRESSED_PATH, "impulse_responses", "RIR_table_simple.csv" + ) + response = requests.get(RIR_table_simple_URL) + if response.status_code == 200: + with open(file_path, "wb") as file: + file.write(response.content) + print("\nRIR_simple_table downloaded successfully.") + + else: + print( + f"\nFailed to download RIR_simple_table. Status code: {response.status_code}" + ) + + +def download_file( + download_url, download_path, split_name, filename, resume_byte_pos=None +): + """ + Download file from given URL + + Arguments + --------- + download_url : str + URL of file being downloaded + download_path : str + Full path of the file that is to be downloaded + (or already downloaded) + split_name : str + Split name of the file being downloaded + e.g. read_speech + filename : str + Fielname of the file being downloaded + resume_byte_pos: (int, optional) + Starting byte position for resuming the download. + Default is None, which means a fresh download. + + Returns + ------- + bool + If True, the file need not be downloaded again. + Else the download might have failed or is incomplete. 
+ """ + print("Downloading:", split_name, "=>", filename) + resume_header = ( + {"Range": f"bytes={resume_byte_pos}-"} if resume_byte_pos else None + ) + response = requests.get(download_url, headers=resume_header, stream=True) + file_size = int(response.headers.get("Content-Length")) + + mode = "ab" if resume_byte_pos else "wb" + initial_pos = resume_byte_pos if resume_byte_pos else 0 + + with open(download_path, mode) as f: + with tqdm( + total=file_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + initial=initial_pos, + miniters=1, + ) as pbar: + for chunk in response.iter_content(32 * 1024): + f.write(chunk) + pbar.update(len(chunk)) + + # Validate downloaded file + if validate_file(download_url, download_path): + return True + else: + print("Download failed. Moving on.") + return False + + +def download_file_parallel(args): + """ + Downloads a file in parallel using the provided arguments. It + makes use of `download_file` function to download the required file. + + Arguments + --------- + args : tuple + Tuple containing the download URL, download path, split + name, filename, and required bytes to be downloaded. + """ + download_url, download_path, split_name, filename, resume_byte_pos = args + download_file( + download_url, + download_path, + split_name, + filename, + resume_byte_pos=resume_byte_pos, + ) + + +def parallel_download(): + """ + Perform parallel download of files using `using ThreadPoolExecutor`. + """ + with ThreadPoolExecutor() as executor: + futures = [] + for file_url in BLOB_NAMES: + for split in SPLIT_LIST: + if split in file_url: + split_name = split + split_path = os.path.join(COMPRESSED_PATH, split_name) + if not os.path.exists(split_path): + os.makedirs(split_path) + if not os.path.exists(DECOMPRESSED_PATH): + os.makedirs(DECOMPRESSED_PATH) + + filename = file_url.split("/")[-1] + download_path = os.path.join(split_path, filename) + download_url = AZURE_URL + "/" + file_url + + if not validate_file(download_url, download_path): + if os.path.exists(download_path): + resume_byte_pos = os.path.getsize(download_path) + else: + resume_byte_pos = None + args = ( + download_url, + download_path, + split_name, + filename, + resume_byte_pos, + ) + futures.append(executor.submit(download_file_parallel, args)) + # download_file(download_url, download_path, split_name, filename) + # decompress_file(download_path, DECOMPRESSED_PATH) + else: + print(", \tDownload complete. Skipping") + decompress_file(download_path, DECOMPRESSED_PATH, split_name) + + for future in futures: + future.result() + + # Download RIR (impulse response) & BLIND testset + rir_blind_test_download() + + +def decompress_file(file, decompress_path, split_name): + """ + Decompress the downloaded file if the target folder does not exist. + + Arguments + --------- + file : str + Path to the compressed downloaded file + decompress_path : str + Path to store the decompressed audio files + """ + for _, dirs, _ in os.walk(decompress_path): + if split_name in dirs: + print("\tDecompression skipped. Folder already exists.") + return True + + if "sim_rir_16k" in file: + slr26_dir = os.path.join(decompress_path, "SLR26") + if os.path.exists(slr26_dir): + print("\tDecompression skipped. Folder already exists.") + return True + + if "rirs_noises" in file: + slr28_dir = os.path.join(decompress_path, "SLR28") + if os.path.exists(slr28_dir): + print("\tDecompression skipped. 
Folder already exists.") + return True + + print("\tDecompressing...") + file_extension = os.path.splitext(file)[-1].lower() + if file_extension == ".zip": + zip = zipfile.ZipFile(file, "r") + zip.extractall(decompress_path) + rename_rirs(decompress_path) + + elif file_extension == ".bz2": + tar = tarfile.open(file, "r:bz2") + tar.extractall(decompress_path) + tar.close() + else: + print("Unsupported file format. Only zip and bz2 files are supported.") + # os.remove(file) + + +def rename_rirs(decompress_path): + """ + Rename directories containing simulated room impulse responses + (RIRs). + + Arguments + --------- + decompress_path (str): The path to the directory containing the RIRs + + Returns + ------- + None + """ + try: + os.rename( + os.path.join(decompress_path, "simulated_rirs_16k"), + os.path.join(decompress_path, "SLR26"), + ) + except Exception: + pass + try: + os.rename( + os.path.join(decompress_path, "RIRS_NOISES"), + os.path.join(decompress_path, "SLR28"), + ) + except Exception: + pass + + +def validate_file(download_url, download_path): + """ + Validate the downloaded file and resume the download if needed. + + Arguments + --------- + download_url : str + URL of the file being downloaded + download_path : str + Full path of the file that is to be downloaded + (or already downloaded) + + Returns + ------- + bool + If True, the file need not be downloaded again. + Else, either the file is not yet downloaded or + partially downloaded, thus resume the download. + """ + if not os.path.isfile(download_path): + # File not yet downloaded + return False + + # Get file size in MB + actual_size = urllib.request.urlopen( + download_url, + context=ssl.create_default_context(cafile=certifi.where()), + ).length + + download_size = os.path.getsize(download_path) + + print( + "File: {}, \t downloaded {} MB out of {} MB".format( + download_path.split("/")[-1], + download_size // (1024 * 1024), + actual_size // (1024 * 1024), + ), + end="", + ) + # Set a margin of 100 MB. We skip re-downloading the file if downloaded + # size differs from actual size by max 100 MB. More than this margin, + # re-download is to attempted. + if actual_size - download_size < 100 * 1024 * 1024: + return True + else: + print(", \tIncomplete download. Resuming...") + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Download and extract DNS dataset." 
+ ) + parser.add_argument( + "--compressed_path", + type=str, + default="DNS-compressed", + help="Path to store the compressed data.", + ) + parser.add_argument( + "--decompressed_path", + type=str, + default="DNS-dataset", + help="Path to store the decompressed data.", + ) + + parser.add_argument( + "--parallel_download", + action="store_true", + help="Use parallel download.", + ) + + args = parser.parse_args() + + COMPRESSED_PATH = args.compressed_path + DECOMPRESSED_PATH = args.decompressed_path + + if args.parallel_download: + parallel_download() + else: + prepare_download() + + # Modfy contents inside RIR_simple_table.csv + file_path = os.path.join( + DECOMPRESSED_PATH, "impulse_responses", "RIR_table_simple.csv" + ) + full_path = os.path.abspath(os.path.dirname(file_path)) + + replacements = { + "datasets/impulse_responses/SLR26/simulated_rirs_16k": os.path.join( + full_path, "SLR26" + ), + "datasets/impulse_responses/SLR28/RIRS_NOISES": os.path.join( + full_path, "SLR28" + ), + } + + # Perform the replacements directly in the file using fileinput module + with fileinput.FileInput(file_path, inplace=True) as file: + for line in file: + for original, replacement in replacements.items(): + line = line.replace(original, replacement) + print(line, end="") + + if not os.path.exists( + os.path.join("noisyspeech_synthesizer", "RIR_table_simple.csv") + ): + shutil.move(file_path, "noisyspeech_synthesizer") diff --git a/recipes/DNS/enhancement/README.md b/recipes/DNS/enhancement/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ee4f1d01c7659ce5b387da53e4e2fe4099c99e1b --- /dev/null +++ b/recipes/DNS/enhancement/README.md @@ -0,0 +1,74 @@ +# **Speech enhancement with Microsoft DNS dataset** +This folder contains the recipe for speech enhancement on Deep Noise Suppression (DNS) Challenge 4 (ICASSP 2022) dataset using SepFormer. + +For data download and prepration, please refer to the `README.md` in `recipes/DNS/` + +## **Start training** +``` +python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_dev_shards_data> +``` +## **DNSMOS Evaluation on baseline-testclips** +*Reference: [Offical repo](https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS) <br>* +Download the evalution models from [Offical repo](https://github.com/microsoft/DNS-Challenge/tree/master/DNSMOS) and save it under `DNSMOS`. Then, to run DNSMOS evalution on the baseline-testclips saved in the above step. +``` +# Model=SepFormer +python dnsmos_local.py -t results/sepformer-enhancement-16k/1234/save/baseline_audio_results/enhanced_testclips/ -o dnsmos_enhance.csv + +# Model=Noisy +python dnsmos_local.py -t <path-to/datasets_fullband/dev_testset/noisy_testclips/> -o dnsmos_noisy.csv +``` + +## **Results** +1. The DNS challenge doesn't provide the ground-truth clean files for dev test. Therefore, we randomly separate out 5% of training set as valid set so that we can compute valid stats like Si-SNR and PESQ during validation. Here we show validation performance. + + | Sampling rate | Valid Si-SNR | Valid PESQ | HuggingFace link | Full Model link | + |---------------|--------------|------------|-------------------|------------| + | 16k | -10.6 | 2.06 | [HuggingFace](https://huggingface.co/speechbrain/sepformer-dns4-16k-enhancement) | https://www.dropbox.com/sh/d3rp5d3gjysvy7c/AACmwcEkm_IFvaW1lt2GdtQka?dl=0 | + +2. Evaluation on DNS4 2022 baseline dev set using DNSMOS. 
+ + | Model | SIG | BAK | OVRL | + |------------|--------|--------|--------| + | Noisy | 2.984 | 2.560 | 2.205 | + | Baseline: NSNet2| 3.014 | 3.942 | 2.712 | + | **SepFormer** | 2.999 | 3.076 | 2.437 | + +We performed 45 epochs of training for the enhancement using an 8 X RTXA6000 48GB GPU. On average, each epoch took approximately 9.25 hours to complete. **Consider training it for atleast 90-100 epochs for superior performance.** + +**NOTE** +- Refer [NSNet2](https://github.com/microsoft/DNS-Challenge/tree/5582dcf5ba43155621de72a035eb54a7d233af14/NSNet2-baseline) on how to perform enhancement on baseline dev set (noisy testclips) using the baseline model- NSNet2. + +## **Computing power** +Kindly be aware that in terms of computational power, training can be extremely resource demanding due to the dataset's large size and the complexity of the SepFormer model. To handle the size of 1300 hours of clean-noisy pairs, we employed a multi-GPU distributed data-parallel (DDP) training scheme on an Nvidia 8 X RTXA6000 48GB GPU. The training process lasted for 17 days, for just 45 epochs. + +## **About SpeechBrain** +- Website: https://speechbrain.github.io/ +- Code: https://github.com/speechbrain/speechbrain/ +- HuggingFace: https://huggingface.co/speechbrain/ + + +## **Citing SpeechBrain** +Please, cite SpeechBrain if you use it for your research or business. + +```bibtex +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` + + +**Citing SepFormer** +```bibtex +@inproceedings{subakan2021attention, + title={Attention is All You Need in Speech Separation}, + author={Cem Subakan and Mirco Ravanelli and Samuele Cornell and Mirko Bronzi and Jianyuan Zhong}, + year={2021}, + booktitle={ICASSP 2021} +} +``` diff --git a/recipes/DNS/enhancement/composite_eval.py b/recipes/DNS/enhancement/composite_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..68353caf515ea1b91cc53032864a415aef074cd5 --- /dev/null +++ b/recipes/DNS/enhancement/composite_eval.py @@ -0,0 +1,466 @@ +"""Composite objective enhancement scores in Python (CSIG, CBAK, COVL) + +Taken from https://github.com/facebookresearch/denoiser/blob/master/scripts/matlab_eval.py + +Authors + * adiyoss (https://github.com/adiyoss) +""" + +from scipy.linalg import toeplitz +from tqdm import tqdm +from pesq import pesq +import librosa +import numpy as np +import os +import sys + + +def eval_composite(ref_wav, deg_wav, sample_rate): + """Evaluate audio quality metrics based on reference + and degraded audio signals. + This function computes various audio quality metrics, + including PESQ, CSIG, CBAK, and COVL, based on the + reference and degraded audio signals provided. 
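+
+    Arguments
+    ---------
+    ref_wav : np.ndarray
+        Reference (clean) waveform.
+    deg_wav : np.ndarray
+        Degraded (enhanced) waveform.
+    sample_rate : int
+        Sampling rate of both signals.
+
+    Returns
+    -------
+    dict
+        Dictionary with the keys "pesq", "csig", "cbak" and "covl".
+
+    Example
+    -------
+    A minimal sketch, assuming `clean` and `enhanced` are 1-D numpy arrays
+    sampled at 16 kHz::
+
+        scores = eval_composite(clean, enhanced, 16000)
+        print(scores["csig"], scores["cbak"], scores["covl"])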
+ """ + ref_wav = ref_wav.reshape(-1) + deg_wav = deg_wav.reshape(-1) + + alpha = 0.95 + len_ = min(ref_wav.shape[0], deg_wav.shape[0]) + ref_wav = ref_wav[:len_] + deg_wav = deg_wav[:len_] + + # Compute WSS measure + wss_dist_vec = wss(ref_wav, deg_wav, sample_rate) + wss_dist_vec = sorted(wss_dist_vec, reverse=False) + wss_dist = np.mean(wss_dist_vec[: int(round(len(wss_dist_vec) * alpha))]) + + # Compute LLR measure + LLR_dist = llr(ref_wav, deg_wav, sample_rate) + LLR_dist = sorted(LLR_dist, reverse=False) + LLRs = LLR_dist + LLR_len = round(len(LLR_dist) * alpha) + llr_mean = np.mean(LLRs[:LLR_len]) + + # Compute the SSNR + snr_mean, segsnr_mean = SSNR(ref_wav, deg_wav, sample_rate) + segSNR = np.mean(segsnr_mean) + + # Compute the PESQ + pesq_raw = PESQ(ref_wav, deg_wav, sample_rate) + + Csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist + Csig = trim_mos(Csig) + Cbak = 1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR + Cbak = trim_mos(Cbak) + Covl = 1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist + Covl = trim_mos(Covl) + + return {"pesq": pesq_raw, "csig": Csig, "cbak": Cbak, "covl": Covl} + + +# ----------------------------- HELPERS ------------------------------------ # +def trim_mos(val): + """Trim a value to be within the MOS (Mean Opinion Score) + range [1, 5]. + """ + return min(max(val, 1), 5) + + +def lpcoeff(speech_frame, model_order): + """Calculate linear prediction (LP) coefficients using + the autocorrelation method. + """ + # (1) Compute Autocor lags + winlength = speech_frame.shape[0] + R = [] + for k in range(model_order + 1): + first = speech_frame[: (winlength - k)] + second = speech_frame[k:winlength] + R.append(np.sum(first * second)) + + # (2) Lev-Durbin + a = np.ones((model_order,)) + E = np.zeros((model_order + 1,)) + rcoeff = np.zeros((model_order,)) + E[0] = R[0] + for i in range(model_order): + if i == 0: + sum_term = 0 + else: + a_past = a[:i] + sum_term = np.sum(a_past * np.array(R[i:0:-1])) + rcoeff[i] = (R[i + 1] - sum_term) / E[i] + a[i] = rcoeff[i] + if i > 0: + a[:i] = a_past[:i] - rcoeff[i] * a_past[::-1] + E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i] + acorr = np.array(R, dtype=np.float32) + refcoeff = np.array(rcoeff, dtype=np.float32) + a = a * -1 + lpparams = np.array([1] + list(a), dtype=np.float32) + acorr = np.array(acorr, dtype=np.float32) + refcoeff = np.array(refcoeff, dtype=np.float32) + lpparams = np.array(lpparams, dtype=np.float32) + + return acorr, refcoeff, lpparams + + +# -------------------------------------------------------------------------- # + +# ---------------------- Speech Quality Metric ----------------------------- # +def PESQ(ref_wav, deg_wav, sample_rate): + """Compute PESQ score. + """ + psq_mode = "wb" if sample_rate == 16000 else "nb" + return pesq(sample_rate, ref_wav, deg_wav, psq_mode) + + +def SSNR(ref_wav, deg_wav, srate=16000, eps=1e-10): + """ Segmental Signal-to-Noise Ratio Objective Speech Quality Measure + This function implements the segmental signal-to-noise ratio + as defined in [1, p. 45] (see Equation 2.12). + """ + clean_speech = ref_wav + processed_speech = deg_wav + clean_length = ref_wav.shape[0] + + # scale both to have same dynamic range. Remove DC too. 
+ clean_speech -= clean_speech.mean() + processed_speech -= processed_speech.mean() + processed_speech *= np.max(np.abs(clean_speech)) / np.max( + np.abs(processed_speech) + ) + + # Signal-to-Noise Ratio + dif = ref_wav - deg_wav + overall_snr = 10 * np.log10( + np.sum(ref_wav ** 2) / (np.sum(dif ** 2) + 10e-20) + ) + # global variables + winlength = int(np.round(30 * srate / 1000)) # 30 msecs + skiprate = winlength // 4 + MIN_SNR = -10 + MAX_SNR = 35 + + # For each frame, calculate SSNR + num_frames = int(clean_length / skiprate - (winlength / skiprate)) + start = 0 + time = np.linspace(1, winlength, winlength) / (winlength + 1) + window = 0.5 * (1 - np.cos(2 * np.pi * time)) + segmental_snr = [] + + for frame_count in range(int(num_frames)): + # (1) get the frames for the test and ref speech. + # Apply Hanning Window + clean_frame = clean_speech[start : start + winlength] + processed_frame = processed_speech[start : start + winlength] + clean_frame = clean_frame * window + processed_frame = processed_frame * window + + # (2) Compute Segmental SNR + signal_energy = np.sum(clean_frame ** 2) + noise_energy = np.sum((clean_frame - processed_frame) ** 2) + segmental_snr.append( + 10 * np.log10(signal_energy / (noise_energy + eps) + eps) + ) + segmental_snr[-1] = max(segmental_snr[-1], MIN_SNR) + segmental_snr[-1] = min(segmental_snr[-1], MAX_SNR) + start += int(skiprate) + return overall_snr, segmental_snr + + +def wss(ref_wav, deg_wav, srate): + """ Calculate Weighted Spectral Slope (WSS) distortion + measure between reference and degraded audio signals. + This function computes the WSS distortion measure using + critical band filters and spectral slope differences. + """ + clean_speech = ref_wav + processed_speech = deg_wav + clean_length = ref_wav.shape[0] + processed_length = deg_wav.shape[0] + + assert clean_length == processed_length, clean_length + + winlength = round(30 * srate / 1000.0) # 240 wlen in samples + skiprate = np.floor(winlength / 4) + max_freq = srate / 2 + num_crit = 25 # num of critical bands + + n_fft = int(2 ** np.ceil(np.log(2 * winlength) / np.log(2))) + n_fftby2 = int(n_fft / 2) + Kmax = 20 + Klocmax = 1 + + # Critical band filter definitions (Center frequency and BW in Hz) + cent_freq = [ + 50.0, + 120, + 190, + 260, + 330, + 400, + 470, + 540, + 617.372, + 703.378, + 798.717, + 904.128, + 1020.38, + 1148.30, + 1288.72, + 1442.54, + 1610.70, + 1794.16, + 1993.93, + 2211.08, + 2446.71, + 2701.97, + 2978.04, + 3276.17, + 3597.63, + ] + bandwidth = [ + 70.0, + 70, + 70, + 70, + 70, + 70, + 70, + 77.3724, + 86.0056, + 95.3398, + 105.411, + 116.256, + 127.914, + 140.423, + 153.823, + 168.154, + 183.457, + 199.776, + 217.153, + 235.631, + 255.255, + 276.072, + 298.126, + 321.465, + 346.136, + ] + + bw_min = bandwidth[0] # min critical bandwidth + + # set up critical band filters. Note here that Gaussianly shaped filters + # are used. Also, the sum of the filter weights are equivalent for each + # critical band filter. Filter less than -30 dB and set to zero. 
+ min_factor = np.exp(-30.0 / (2 * 2.303)) # -30 dB point of filter + + crit_filter = np.zeros((num_crit, n_fftby2)) + all_f0 = [] + for i in range(num_crit): + f0 = (cent_freq[i] / max_freq) * (n_fftby2) + all_f0.append(np.floor(f0)) + bw = (bandwidth[i] / max_freq) * (n_fftby2) + norm_factor = np.log(bw_min) - np.log(bandwidth[i]) + j = list(range(n_fftby2)) + crit_filter[i, :] = np.exp( + -11 * (((j - np.floor(f0)) / bw) ** 2) + norm_factor + ) + crit_filter[i, :] = crit_filter[i, :] * (crit_filter[i, :] > min_factor) + + # For each frame of input speech, compute Weighted Spectral Slope Measure + num_frames = int(clean_length / skiprate - (winlength / skiprate)) + start = 0 # starting sample + time = np.linspace(1, winlength, winlength) / (winlength + 1) + window = 0.5 * (1 - np.cos(2 * np.pi * time)) + distortion = [] + + for frame_count in range(num_frames): + # (1) Get the Frames for the test and reference speeech. + # Multiply by Hanning window. + clean_frame = clean_speech[start : start + winlength] + processed_frame = processed_speech[start : start + winlength] + clean_frame = clean_frame * window + processed_frame = processed_frame * window + + # (2) Compuet Power Spectrum of clean and processed + clean_spec = np.abs(np.fft.fft(clean_frame, n_fft)) ** 2 + processed_spec = np.abs(np.fft.fft(processed_frame, n_fft)) ** 2 + clean_energy = [None] * num_crit + processed_energy = [None] * num_crit + + # (3) Compute Filterbank output energies (in dB) + for i in range(num_crit): + clean_energy[i] = np.sum(clean_spec[:n_fftby2] * crit_filter[i, :]) + processed_energy[i] = np.sum( + processed_spec[:n_fftby2] * crit_filter[i, :] + ) + clean_energy = np.array(clean_energy).reshape(-1, 1) + eps = np.ones((clean_energy.shape[0], 1)) * 1e-10 + clean_energy = np.concatenate((clean_energy, eps), axis=1) + clean_energy = 10 * np.log10(np.max(clean_energy, axis=1)) + processed_energy = np.array(processed_energy).reshape(-1, 1) + processed_energy = np.concatenate((processed_energy, eps), axis=1) + processed_energy = 10 * np.log10(np.max(processed_energy, axis=1)) + + # (4) Compute Spectral Shape (dB[i+1] - dB[i]) + clean_slope = clean_energy[1:num_crit] - clean_energy[: num_crit - 1] + processed_slope = ( + processed_energy[1:num_crit] - processed_energy[: num_crit - 1] + ) + + # (5) Find the nearest peak locations in the spectra to each + # critical band. If the slope is negative, we search + # to the left. If positive, we search to the right. + clean_loc_peak = [] + processed_loc_peak = [] + for i in range(num_crit - 1): + if clean_slope[i] > 0: + # search to the right + n = i + while n < num_crit - 1 and clean_slope[n] > 0: + n += 1 + clean_loc_peak.append(clean_energy[n - 1]) + else: + # search to the left + n = i + while n >= 0 and clean_slope[n] <= 0: + n -= 1 + clean_loc_peak.append(clean_energy[n + 1]) + # find the peaks in the processed speech signal + if processed_slope[i] > 0: + n = i + while n < num_crit - 1 and processed_slope[n] > 0: + n += 1 + processed_loc_peak.append(processed_energy[n - 1]) + else: + n = i + while n >= 0 and processed_slope[n] <= 0: + n -= 1 + processed_loc_peak.append(processed_energy[n + 1]) + + # (6) Compuet the WSS Measure for this frame. This includes + # determination of the weighting functino + dBMax_clean = max(clean_energy) + dBMax_processed = max(processed_energy) + + # The weights are calculated by averaging individual + # weighting factors from the clean and processed frame. 
+ # These weights W_clean and W_processed should range + # from 0 to 1 and place more emphasis on spectral + # peaks and less emphasis on slope differences in spectral + # valleys. This procedure is described on page 1280 of + # Klatt's 1982 ICASSP paper. + clean_loc_peak = np.array(clean_loc_peak) + processed_loc_peak = np.array(processed_loc_peak) + Wmax_clean = Kmax / (Kmax + dBMax_clean - clean_energy[: num_crit - 1]) + Wlocmax_clean = Klocmax / ( + Klocmax + clean_loc_peak - clean_energy[: num_crit - 1] + ) + W_clean = Wmax_clean * Wlocmax_clean + Wmax_processed = Kmax / ( + Kmax + dBMax_processed - processed_energy[: num_crit - 1] + ) + Wlocmax_processed = Klocmax / ( + Klocmax + processed_loc_peak - processed_energy[: num_crit - 1] + ) + W_processed = Wmax_processed * Wlocmax_processed + W = (W_clean + W_processed) / 2 + distortion.append( + np.sum( + W + * ( + clean_slope[: num_crit - 1] + - processed_slope[: num_crit - 1] + ) + ** 2 + ) + ) + + # this normalization is not part of Klatt's paper, but helps + # to normalize the meaasure. Here we scale the measure by the sum of the + # weights + distortion[frame_count] = distortion[frame_count] / np.sum(W) + start += int(skiprate) + return distortion + + +def llr(ref_wav, deg_wav, srate): + """Calculate Log Likelihood Ratio (LLR) distortion measure + between reference and degraded audio signals. This function + computes the LLR distortion measure between reference and + degraded audio signals using LPC analysis and autocorrelation + logs. + """ + clean_speech = ref_wav + processed_speech = deg_wav + clean_length = ref_wav.shape[0] + processed_length = deg_wav.shape[0] + assert clean_length == processed_length, clean_length + + winlength = round(30 * srate / 1000.0) # 240 wlen in samples + skiprate = np.floor(winlength / 4) + if srate < 10000: + # LPC analysis order + P = 10 + else: + P = 16 + + # For each frame of input speech, calculate the Log Likelihood Ratio + num_frames = int(clean_length / skiprate - (winlength / skiprate)) + start = 0 + time = np.linspace(1, winlength, winlength) / (winlength + 1) + window = 0.5 * (1 - np.cos(2 * np.pi * time)) + distortion = [] + + for frame_count in range(num_frames): + # (1) Get the Frames for the test and reference speeech. + # Multiply by Hanning window. 
+ clean_frame = clean_speech[start : start + winlength] + processed_frame = processed_speech[start : start + winlength] + clean_frame = clean_frame * window + processed_frame = processed_frame * window + + # (2) Get the autocorrelation logs and LPC params used + # to compute the LLR measure + R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P) + R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P) + A_clean = A_clean[None, :] + A_processed = A_processed[None, :] + + # (3) Compute the LLR measure + numerator = A_processed.dot(toeplitz(R_clean)).dot(A_processed.T) + denominator = A_clean.dot(toeplitz(R_clean)).dot(A_clean.T) + + if (numerator / denominator) <= 0: + print(f"Numerator: {numerator}") + print(f"Denominator: {denominator}") + + log_ = np.log(numerator / denominator) + distortion.append(np.squeeze(log_)) + start += int(skiprate) + return np.nan_to_num(np.array(distortion)) + + +# -------------------------------------------------------------------------- # + +if __name__ == "__main__": + clean_path = sys.argv[1] + enhanced_path = sys.argv[2] + csig, cbak, covl, count = 0, 0, 0, 0 + for _file in tqdm(os.listdir(clean_path)): + if _file.endswith("wav"): + clean_path_f = os.path.join(clean_path, _file) + enhanced_path_f = os.path.join( + enhanced_path, _file[:-4] + "_enhanced.wav" + ) + clean_sig = librosa.load(clean_path_f, sr=None)[0] + enhanced_sig = librosa.load(enhanced_path_f, sr=None)[0] + res = eval_composite(clean_sig, enhanced_sig) + csig += res["csig"] + cbak += res["cbak"] + covl += res["covl"] + pesq += res["pesq"] + count += 1 + print(f"CSIG: {csig/count}, CBAK: {cbak/count}, COVL: {covl/count}") diff --git a/recipes/DNS/enhancement/dnsmos_local.py b/recipes/DNS/enhancement/dnsmos_local.py new file mode 100644 index 0000000000000000000000000000000000000000..0e334e88527737b9b3f10d918aec81aed6eaae2c --- /dev/null +++ b/recipes/DNS/enhancement/dnsmos_local.py @@ -0,0 +1,195 @@ +""" +Usage: + python dnsmos_local.py -t path/to/sepformer_enhc_clips -o dnsmos_enhance.csv + +Ownership: Microsoft +""" + +import argparse +import concurrent.futures +import glob +import os + +import librosa +import numpy as np +import onnxruntime as ort +import pandas as pd +import soundfile as sf +from tqdm import tqdm + +SAMPLING_RATE = 16000 +INPUT_LENGTH = 9.01 + + +class ComputeScore: + """A class for computing MOS scores using an ONNX model and polynomial fitting. + """ + + def __init__(self, primary_model_path) -> None: + """Initialize the ComputeScore class. + """ + self.onnx_sess = ort.InferenceSession(primary_model_path) + + def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS): + """Calculate MOS scores using polynomial fitting. + Returns a tuple containing MOS scores for speech, + background, and overall quality. + """ + # if is_personalized_MOS: + # p_ovr = np.poly1d([-0.00533021, 0.005101 , 1.18058466, -0.11236046]) + # p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + # p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611 , 0.96883132]) + # else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + sig_poly = p_sig(sig) + bak_poly = p_bak(bak) + ovr_poly = p_ovr(ovr) + + return sig_poly, bak_poly, ovr_poly + + def __call__(self, fpath, sampling_rate, is_personalized_MOS): + """Compute MOS scores for an audio segment. 
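+
+        Arguments
+        ---------
+        fpath : str
+            Path to the audio clip to evaluate.
+        sampling_rate : int
+            Target sampling rate; the clip is resampled if its rate differs.
+        is_personalized_MOS : bool
+            Whether the personalized MOS variant was requested.
+
+        Returns
+        -------
+        dict
+            Per-clip results, including raw and polynomial-mapped
+            SIG, BAK and OVRL scores.
+
+        Example
+        -------
+        A minimal sketch, assuming the DNSMOS model file has been downloaded
+        to ``DNSMOS/sig_bak_ovr.onnx`` (paths are placeholders)::
+
+            scorer = ComputeScore("DNSMOS/sig_bak_ovr.onnx")
+            result = scorer("enhanced_clip.wav", 16000, False)
+            print(result["SIG"], result["BAK"], result["OVRL"])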
+ """ + aud, input_fs = sf.read(fpath) + fs = sampling_rate + if input_fs != fs: + audio = librosa.resample(aud, input_fs, fs) + else: + audio = aud + actual_audio_len = len(audio) + len_samples = int(INPUT_LENGTH * fs) + while len(audio) < len_samples: + audio = np.append(audio, audio) + + num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1 + hop_len_samples = fs + predicted_mos_sig_seg_raw = [] + predicted_mos_bak_seg_raw = [] + predicted_mos_ovr_seg_raw = [] + predicted_mos_sig_seg = [] + predicted_mos_bak_seg = [] + predicted_mos_ovr_seg = [] + + for idx in range(num_hops): + audio_seg = audio[ + int(idx * hop_len_samples) : int( + (idx + INPUT_LENGTH) * hop_len_samples + ) + ] + if len(audio_seg) < len_samples: + continue + + input_features = np.array(audio_seg).astype("float32")[ + np.newaxis, : + ] + oi = {"input_1": input_features} + mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run( + None, oi + )[0][0] + mos_sig, mos_bak, mos_ovr = self.get_polyfit_val( + mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS=0 + ) + predicted_mos_sig_seg_raw.append(mos_sig_raw) + predicted_mos_bak_seg_raw.append(mos_bak_raw) + predicted_mos_ovr_seg_raw.append(mos_ovr_raw) + predicted_mos_sig_seg.append(mos_sig) + predicted_mos_bak_seg.append(mos_bak) + predicted_mos_ovr_seg.append(mos_ovr) + + clip_dict = { + "filename": fpath, + "len_in_sec": actual_audio_len / fs, + "sr": fs, + } + clip_dict["num_hops"] = num_hops + clip_dict["OVRL_raw"] = np.mean(predicted_mos_ovr_seg_raw) + clip_dict["SIG_raw"] = np.mean(predicted_mos_sig_seg_raw) + clip_dict["BAK_raw"] = np.mean(predicted_mos_bak_seg_raw) + clip_dict["OVRL"] = np.mean(predicted_mos_ovr_seg) + clip_dict["SIG"] = np.mean(predicted_mos_sig_seg) + clip_dict["BAK"] = np.mean(predicted_mos_bak_seg) + return clip_dict + + +def main(args): + models = glob.glob(os.path.join(args.testset_dir, "*")) + audio_clips_list = [] + + if args.personalized_MOS: + primary_model_path = os.path.join("pDNSMOS", "sig_bak_ovr.onnx") + else: + primary_model_path = os.path.join("DNSMOS", "sig_bak_ovr.onnx") + + compute_score = ComputeScore(primary_model_path) + + rows = [] + clips = [] + clips = glob.glob(os.path.join(args.testset_dir, "*.wav")) + is_personalized_eval = args.personalized_MOS + desired_fs = SAMPLING_RATE + for m in tqdm(models): + max_recursion_depth = 10 + audio_path = os.path.join(args.testset_dir, m) + audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav")) + while len(audio_clips_list) == 0 and max_recursion_depth > 0: + audio_path = os.path.join(audio_path, "**") + audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav")) + max_recursion_depth -= 1 + clips.extend(audio_clips_list) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_url = { + executor.submit( + compute_score, clip, desired_fs, is_personalized_eval + ): clip + for clip in clips + } + for future in tqdm(concurrent.futures.as_completed(future_to_url)): + clip = future_to_url[future] + try: + data = future.result() + except Exception as exc: + print("%r generated an exception: %s" % (clip, exc)) + else: + rows.append(data) + + df = pd.DataFrame(rows) + if args.csv_path: + csv_path = args.csv_path + df.to_csv(csv_path) + else: + print(df.describe()) + + print("======== DNSMOS scores ======== ") + print("SIG:", df.loc[:, "SIG"].mean()) + print("BAK:", df.loc[:, "BAK"].mean()) + print("OVRL:", df.loc[:, "OVRL"].mean()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-t", + 
"--testset_dir", + default=".", + help="Path to the dir containing audio clips in .wav to be evaluated", + ) + parser.add_argument( + "-o", + "--csv_path", + default=None, + help="Dir to the csv that saves the results", + ) + parser.add_argument( + "-p", + "--personalized_MOS", + action="store_true", + help="Flag to indicate if personalized MOS score is needed or regular", + ) + + args = parser.parse_args() + + main(args) diff --git a/recipes/DNS/enhancement/extra_requirements.txt b/recipes/DNS/enhancement/extra_requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c9d0e2fce43d3e973ed4333d087bc51019fc987 --- /dev/null +++ b/recipes/DNS/enhancement/extra_requirements.txt @@ -0,0 +1,8 @@ +librosa +mir_eval +onnxruntime +pesq +pyroomacoustics==0.3.1 +pystoi +tensorboard +webdataset diff --git a/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml b/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1a807d16e189ad99b18730675d63b4d78b600865 --- /dev/null +++ b/recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml @@ -0,0 +1,183 @@ +# ################################ +# Model: SepFormer model for speech enhancement +# https://arxiv.org/abs/2010.13154 +# +# Author: Sangeet Sagar 2022 +# Dataset : Microsoft-DNS 4 +# ################################ + +# Basic parameters +# Seed needs to be set at top of yaml, before objects with parameters are made +seed: 1234 +__set_seed: !apply:torch.manual_seed [!ref <seed>] +output_folder: !ref results/sepformer-enhancement-16k/<seed> +save_folder: !ref <output_folder>/save +train_log: !ref <output_folder>/train_log.txt + +# Data params +data_folder: !PLACEHOLDER # ../noisyspeech_synthesizer/synthesized_data_shards/ +train_data: !ref <data_folder>/train_shards/ +valid_data: !ref <data_folder>/valid_shards/ +baseline_noisy_shards_folder: !PLACEHOLDER # ../DNS-shards/devsets_fullband/ +baseline_shards: !ref <baseline_noisy_shards_folder>/shard-{000000..999999}.tar + +# Set to a directory on a large disk if using Webdataset shards hosted on the web. 
+shard_cache_dir: + +# Basic parameters +use_tensorboard: True +tensorboard_logs: !ref <output_folder>/logs/ +dereverberate: False + +# Experiment params +auto_mix_prec: True +test_only: False +num_spks: 1 +noprogressbar: False +save_audio: True # Save estimated sources on disk +sample_rate: 16000 +audio_length: 4 # seconds +n_audio_to_save: 20 + +# Training parameters +N_epochs: 100 +batch_size: 4 +batch_size_test: 1 +lr: 0.00015 +clip_grad_norm: 5 +loss_upper_lim: 999999 # this is the upper limit for an acceptable loss +# if True, the training sequences are cut to a specified length +limit_training_signal_len: False +# this is the length of sequences if we choose to limit +# the signal length of training sequences +training_signal_len: 32000 +ckpt_interval_minutes: 60 + +# Parameters for data augmentation +use_wavedrop: False +use_speedperturb: True +use_rand_shift: False +min_shift: -8000 +max_shift: 8000 + +speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + perturb_prob: 1.0 + drop_freq_prob: 0.0 + drop_chunk_prob: 0.0 + sample_rate: !ref <sample_rate> + speeds: [95, 100, 105] + +wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment + perturb_prob: 0.0 + drop_freq_prob: 1.0 + drop_chunk_prob: 1.0 + sample_rate: !ref <sample_rate> + +# loss thresholding -- this thresholds the training loss +threshold_byloss: True +threshold: -30 + +# Encoder parameters +N_encoder_out: 256 +out_channels: 256 +kernel_size: 16 +kernel_stride: 8 + +# Dataloader options +dataloader_opts: + batch_size: !ref <batch_size> + num_workers: 3 + +dataloader_opts_valid: + batch_size: !ref <batch_size> + num_workers: 3 + +dataloader_opts_test: + batch_size: !ref <batch_size_test> + num_workers: 3 + +# Specifying the network +Encoder: !new:speechbrain.lobes.models.dual_path.Encoder + kernel_size: !ref <kernel_size> + out_channels: !ref <N_encoder_out> + +SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock + num_layers: 8 + d_model: !ref <out_channels> + nhead: 8 + d_ffn: 1024 + dropout: 0 + use_positional_encoding: True + norm_before: True + +SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock + num_layers: 8 + d_model: !ref <out_channels> + nhead: 8 + d_ffn: 1024 + dropout: 0 + use_positional_encoding: True + norm_before: True + +MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model + num_spks: !ref <num_spks> + in_channels: !ref <N_encoder_out> + out_channels: !ref <out_channels> + num_layers: 2 + K: 250 + intra_model: !ref <SBtfintra> + inter_model: !ref <SBtfinter> + norm: ln + linear_layer_after_inter_intra: False + skip_around_intra: True + +Decoder: !new:speechbrain.lobes.models.dual_path.Decoder + in_channels: !ref <N_encoder_out> + out_channels: 1 + kernel_size: !ref <kernel_size> + stride: !ref <kernel_stride> + bias: False + +optimizer: !name:torch.optim.Adam + lr: !ref <lr> + weight_decay: 0 + +loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper + +lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau + factor: 0.5 + patience: 2 + dont_halve_until_epoch: 85 + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref <N_epochs> + +modules: + encoder: !ref <Encoder> + decoder: !ref <Decoder> + masknet: !ref <MaskNet> + +save_all_checkpoints: False +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref <save_folder> + recoverables: + encoder: !ref <Encoder> + decoder: !ref <Decoder> + masknet: !ref <MaskNet> + counter: !ref <epoch_counter> + lr_scheduler: !ref 
<lr_scheduler> + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref <train_log> + +## Uncomment to fine-tune a pre-trained model. +# pretrained_enhancement: !new:speechbrain.utils.parameter_transfer.Pretrainer +# collect_in: !ref <save_folder> +# loadables: +# encoder: !ref <Encoder> +# decoder: !ref <Decoder> +# masknet: !ref <MaskNet> +# paths: +# encoder: !PLACEHOLDER +# decoder: !PLACEHOLDER +# masknet: !PLACEHOLDER diff --git a/recipes/DNS/enhancement/train.py b/recipes/DNS/enhancement/train.py new file mode 100755 index 0000000000000000000000000000000000000000..0fbff02d34bebca053a296b7dc50cce8ada3e50c --- /dev/null +++ b/recipes/DNS/enhancement/train.py @@ -0,0 +1,864 @@ +#!/usr/bin/env/python3 +"""Recipe for training a speech enhancement system on Microsoft DNS +(Deep Noise Suppression) challenge dataset using SepFormer architecture. +The system employs an encoder,a decoder, and a masking network. + +To run this recipe, do the following: +python train.py hparams/sepformer-dns-16k.yaml --data_folder <path/to/synthesized_shards_data> --baseline_noisy_shards_folder <path/to/baseline_shards_data> + +The experiment file is flexible enough to support different neural +networks. By properly changing the parameter files, you can try +different architectures. + +Authors + * Sangeet Sagar 2022 + * Cem Subakan 2020 + * Mirco Ravanelli 2020 + * Samuele Cornell 2020 + * Mirko Bronzi 2020 + * Jianyuan Zhong 2020 +""" + +import os +import glob +import sys +import csv +import json +import logging +import numpy as np +from tqdm import tqdm +from typing import Dict +from functools import partial + +import torch +import torchaudio +import braceexpand +import webdataset as wds +import torch.nn.functional as F +from torch.cuda.amp import autocast + +import speechbrain as sb +from hyperpyyaml import load_hyperpyyaml +from composite_eval import eval_composite +import speechbrain.nnet.schedulers as schedulers +from speechbrain.utils.distributed import run_on_main +from speechbrain.utils.metric_stats import MetricStats +from speechbrain.processing.features import spectral_magnitude +from speechbrain.dataio.batch import PaddedBatch + +from pesq import pesq +from pystoi import stoi + + +# Define training procedure +class Enhancement(sb.Brain): + def compute_forward(self, noisy, clean, stage, noise=None): + """Forward computations from the noisy to the separated signals.""" + # Unpack lists and put tensors in the right device + noisy, noisy_lens = noisy + noisy, noisy_lens = noisy.to(self.device), noisy_lens.to(self.device) + # Convert clean to tensor + clean = clean[0].unsqueeze(-1).to(self.device) + + # Add speech distortions + if stage == sb.Stage.TRAIN: + with torch.no_grad(): + if self.hparams.use_speedperturb or self.hparams.use_rand_shift: + noisy, clean = self.add_speed_perturb(clean, noisy_lens) + + # Reverb already added, not adding any reverb + clean_rev = clean + noisy = clean.sum(-1) + # if we reverberate, we set the clean to be reverberant + if not self.hparams.dereverberate: + clean = clean_rev + + noise = noise.to(self.device) + len_noise = noise.shape[1] + len_noisy = noisy.shape[1] + min_len = min(len_noise, len_noisy) + + # add the noise + noisy = noisy[:, :min_len] + noise[:, :min_len] + + # fix the length of clean also + clean = clean[:, :min_len, :] + + if self.hparams.use_wavedrop: + noisy = self.hparams.wavedrop(noisy, noisy_lens) + + if self.hparams.limit_training_signal_len: + noisy, clean = self.cut_signals(noisy, clean) + + # Enhancement + if 
self.use_freq_domain: + noisy_w = self.compute_feats(noisy) + est_mask = self.modules.masknet(noisy_w) + + sep_h = noisy_w * est_mask + est_source = self.hparams.resynth(torch.expm1(sep_h), noisy) + else: + noisy_w = self.hparams.Encoder(noisy) + est_mask = self.modules.masknet(noisy_w) + + sep_h = noisy_w * est_mask + est_source = self.hparams.Decoder(sep_h[0]) + + # T changed after conv1d in encoder, fix it here + T_origin = noisy.size(1) + T_est = est_source.size(1) + est_source = est_source.squeeze(-1) + if T_origin > T_est: + est_source = F.pad(est_source, (0, T_origin - T_est)) + else: + est_source = est_source[:, :T_origin] + + return [est_source, sep_h], clean.squeeze(-1) + + def compute_feats(self, wavs): + """Feature computation pipeline""" + feats = self.hparams.Encoder(wavs) + feats = spectral_magnitude(feats, power=0.5) + feats = torch.log1p(feats) + return feats + + def compute_objectives(self, predictions, clean): + """Computes the si-snr loss""" + predicted_wavs, predicted_specs = predictions + + if self.use_freq_domain: + target_specs = self.compute_feats(clean) + return self.hparams.loss(target_specs, predicted_specs) + else: + return self.hparams.loss( + clean.unsqueeze(-1), predicted_wavs.unsqueeze(-1) + ) + + def fit_batch(self, batch): + """Trains one batch""" + # Unpacking batch list + noisy = batch.noisy_sig + clean = batch.clean_sig + noise = batch.noise_sig[0] + + if self.auto_mix_prec: + with autocast(): + predictions, clean = self.compute_forward( + noisy, clean, sb.Stage.TRAIN, noise + ) + loss = self.compute_objectives(predictions, clean) + + # hard threshold the easy dataitems + if self.hparams.threshold_byloss: + th = self.hparams.threshold + loss_to_keep = loss[loss > th] + if loss_to_keep.nelement() > 0: + loss = loss_to_keep.mean() + else: + loss = loss.mean() + + if ( + loss < self.hparams.loss_upper_lim and loss.nelement() > 0 + ): # the fix for computational problems + self.scaler.scale(loss).backward() + if self.hparams.clip_grad_norm >= 0: + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_( + self.modules.parameters(), self.hparams.clip_grad_norm, + ) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.nonfinite_count += 1 + logger.info( + "infinite loss or empty loss! it happened {} times so far - skipping this batch".format( + self.nonfinite_count + ) + ) + loss.data = torch.tensor(0).to(self.device) + else: + predictions, clean = self.compute_forward( + noisy, clean, sb.Stage.TRAIN, noise + ) + loss = self.compute_objectives(predictions, clean) + + if self.hparams.threshold_byloss: + th = self.hparams.threshold + loss_to_keep = loss[loss > th] + if loss_to_keep.nelement() > 0: + loss = loss_to_keep.mean() + else: + loss = loss.mean() + + if ( + loss < self.hparams.loss_upper_lim and loss.nelement() > 0 + ): # the fix for computational problems + loss.backward() + if self.hparams.clip_grad_norm >= 0: + torch.nn.utils.clip_grad_norm_( + self.modules.parameters(), self.hparams.clip_grad_norm + ) + self.optimizer.step() + else: + self.nonfinite_count += 1 + logger.info( + "infinite loss or empty loss! 
it happened {} times so far - skipping this batch".format( + self.nonfinite_count + ) + ) + loss.data = torch.tensor(0).to(self.device) + self.optimizer.zero_grad() + + return loss.detach().cpu() + + def evaluate_batch(self, batch, stage): + """Computations needed for validation/test batches""" + + snt_id = batch.id + noisy = batch.noisy_sig + clean = batch.clean_sig + + with torch.no_grad(): + predictions, clean = self.compute_forward(noisy, clean, stage) + loss = self.compute_objectives(predictions, clean) + loss = torch.mean(loss) + + if stage != sb.Stage.TRAIN: + self.pesq_metric.append( + ids=batch.id, predict=predictions[0].cpu(), target=clean.cpu() + ) + + # Manage audio file saving + if stage == sb.Stage.TEST and self.hparams.save_audio: + if hasattr(self.hparams, "n_audio_to_save"): + if self.hparams.n_audio_to_save > 0: + self.save_audio(snt_id[0], noisy, clean, predictions[0]) + self.hparams.n_audio_to_save += -1 + else: + self.save_audio(snt_id[0], noisy, clean, predictions[0]) + + return loss.detach() + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch""" + if stage != sb.Stage.TRAIN: + # Define function taking (prediction, target) for parallel eval + def pesq_eval(pred_wav, target_wav): + """Computes the PESQ evaluation metric""" + psq_mode = "wb" if self.hparams.sample_rate == 16000 else "nb" + try: + return pesq( + fs=self.hparams.sample_rate, + ref=target_wav.numpy(), + deg=pred_wav.numpy(), + mode=psq_mode, + ) + except Exception: + print("pesq encountered an error for this data item") + return 0 + + self.pesq_metric = MetricStats( + metric=pesq_eval, n_jobs=1, batch_eval=False + ) + + def on_stage_end(self, stage, stage_loss, epoch): + """Gets called at the end of a epoch.""" + # Compute/store important stats + stage_stats = {"si-snr": stage_loss} + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + else: + stats = { + "si-snr": stage_loss, + "pesq": self.pesq_metric.summarize("average"), + } + + # Perform end-of-iteration things, like annealing, logging, etc. 
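+        # On validation: log stats to TensorBoard (if enabled), anneal the
+        # learning rate with ReduceLROnPlateau on the validation loss, and
+        # checkpoint keyed on the epoch-average PESQ.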
+ if stage == sb.Stage.VALID: + # Save valid logs in TensorBoard + valid_stats = { + "Epochs": epoch, + "Valid SI-SNR": stage_loss, + "Valid PESQ": self.pesq_metric.summarize("average"), + } + if self.hparams.use_tensorboard: + self.hparams.tensorboard_train_logger.log_stats(valid_stats) + + # Learning rate annealing + if isinstance( + self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau + ): + current_lr, next_lr = self.hparams.lr_scheduler( + [self.optimizer], epoch, stage_loss + ) + schedulers.update_learning_rate(self.optimizer, next_lr) + else: + # if we do not use the reducelronplateau, we do not change the lr + current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"] + + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr": current_lr}, + train_stats=self.train_stats, + valid_stats=stats, + ) + if ( + hasattr(self.hparams, "save_all_checkpoints") + and self.hparams.save_all_checkpoints + ): + self.checkpointer.save_checkpoint(meta={"pesq": stats["pesq"]}) + else: + self.checkpointer.save_and_keep_only( + meta={"pesq": stats["pesq"]}, max_keys=["pesq"], + ) + elif stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stats, + ) + + def add_speed_perturb(self, clean, targ_lens): + """Adds speed perturbation and random_shift to the input signals""" + + min_len = -1 + recombine = False + + if self.hparams.use_speedperturb: + # Performing speed change (independently on each source) + new_clean = [] + recombine = True + + for i in range(clean.shape[-1]): + new_target = self.hparams.speedperturb( + clean[:, :, i], targ_lens + ) + new_clean.append(new_target) + if i == 0: + min_len = new_target.shape[-1] + else: + if new_target.shape[-1] < min_len: + min_len = new_target.shape[-1] + + if self.hparams.use_rand_shift: + # Performing random_shift (independently on each source) + recombine = True + for i in range(clean.shape[-1]): + rand_shift = torch.randint( + self.hparams.min_shift, self.hparams.max_shift, (1,) + ) + new_clean[i] = new_clean[i].to(self.device) + new_clean[i] = torch.roll( + new_clean[i], shifts=(rand_shift[0],), dims=1 + ) + + # Re-combination + if recombine: + if self.hparams.use_speedperturb: + clean = torch.zeros( + clean.shape[0], + min_len, + clean.shape[-1], + device=clean.device, + dtype=torch.float, + ) + for i, new_target in enumerate(new_clean): + clean[:, :, i] = new_clean[i][:, 0:min_len] + + noisy = clean.sum(-1) + return noisy, clean + + def cut_signals(self, noisy, clean): + """This function selects a random segment of a given length withing the noisy. + The corresponding clean are selected accordingly""" + randstart = torch.randint( + 0, + 1 + max(0, noisy.shape[1] - self.hparams.training_signal_len), + (1,), + ).item() + clean = clean[ + :, randstart : randstart + self.hparams.training_signal_len, : + ] + noisy = noisy[ + :, randstart : randstart + self.hparams.training_signal_len + ] + return noisy, clean + + def reset_layer_recursively(self, layer): + """Reinitializes the parameters of the neural networks""" + if hasattr(layer, "reset_parameters"): + layer.reset_parameters() + for child_layer in layer.modules(): + if layer != child_layer: + self.reset_layer_recursively(child_layer) + + def save_results(self, valid_data): + """This script calculates the SDR and SI-SNR metrics + and stores them in a CSV file. 
As this evaluation + method depends on a gold-standard reference signal, + it is applied exclusively to the valid set and excludes + the baseline data. + """ + # This package is required for SDR computation + from mir_eval.separation import bss_eval_sources + + # Create folders where to store audio + save_file = os.path.join( + self.hparams.output_folder, "valid_results.csv" + ) + + # Variable init + all_sdrs = [] + all_sdrs_i = [] + all_sisnrs = [] + all_sisnrs_i = [] + all_pesqs = [] + all_stois = [] + all_csigs = [] + all_cbaks = [] + all_covls = [] + csv_columns = [ + "snt_id", + "sdr", + "sdr_i", + "si-snr", + "si-snr_i", + "pesq", + "stoi", + "csig", + "cbak", + "covl", + ] + + valid_loader = sb.dataio.dataloader.make_dataloader( + valid_data, **self.hparams.dataloader_opts_test + ) + + with open(save_file, "w") as results_csv: + writer = csv.DictWriter(results_csv, fieldnames=csv_columns) + writer.writeheader() + + # Loop over all test sentence + with tqdm(valid_loader, dynamic_ncols=True) as t: + for i, batch in enumerate(t): + # Apply Enhancement + noisy, noisy_len = batch.noisy_sig + snt_id = batch.id + clean = batch.clean_sig + + with torch.no_grad(): + predictions, clean = self.compute_forward( + batch.noisy_sig, clean, sb.Stage.TEST + ) + + # Compute PESQ + psq_mode = ( + "wb" if self.hparams.sample_rate == 16000 else "nb" + ) + + try: + # Compute SI-SNR + sisnr = self.compute_objectives(predictions, clean) + + # Compute SI-SNR improvement + noisy_signal = noisy + + noisy_signal = noisy_signal.to(clean.device) + sisnr_baseline = self.compute_objectives( + [noisy_signal.squeeze(-1), None], clean + ) + sisnr_i = sisnr - sisnr_baseline + + # Compute SDR + sdr, _, _, _ = bss_eval_sources( + clean[0].t().cpu().numpy(), + predictions[0][0].t().detach().cpu().numpy(), + ) + + sdr_baseline, _, _, _ = bss_eval_sources( + clean[0].t().cpu().numpy(), + noisy_signal[0].t().detach().cpu().numpy(), + ) + + sdr_i = sdr.mean() - sdr_baseline.mean() + + # Compute PESQ + psq = pesq( + self.hparams.sample_rate, + clean.squeeze().cpu().numpy(), + predictions[0].squeeze().cpu().numpy(), + mode=psq_mode, + ) + # Compute STOI + stoi_score = stoi( + clean.squeeze().cpu().numpy(), + predictions[0].squeeze().cpu().numpy(), + fs_sig=self.hparams.sample_rate, + extended=False, + ) + # Compute CSIG, CBAK, COVL + composite_metrics = eval_composite( + clean.squeeze().cpu().numpy(), + predictions[0].squeeze().cpu().numpy(), + self.hparams.sample_rate, + ) + except Exception: + # this handles all sorts of error that may + # occur when evaluating an enhanced file. 
+ continue + + # Saving on a csv file + row = { + "snt_id": snt_id[0], + "sdr": sdr.mean(), + "sdr_i": sdr_i, + "si-snr": -sisnr.item(), + "si-snr_i": -sisnr_i.item(), + "pesq": psq, + "stoi": stoi_score, + "csig": composite_metrics["csig"], + "cbak": composite_metrics["cbak"], + "covl": composite_metrics["covl"], + } + writer.writerow(row) + + # Metric Accumulation + all_sdrs.append(sdr.mean()) + all_sdrs_i.append(sdr_i.mean()) + all_sisnrs.append(-sisnr.item()) + all_sisnrs_i.append(-sisnr_i.item()) + all_pesqs.append(psq) + all_stois.append(stoi_score) + all_csigs.append(composite_metrics["csig"]) + all_cbaks.append(composite_metrics["cbak"]) + all_covls.append(composite_metrics["covl"]) + + row = { + "snt_id": "avg", + "sdr": np.array(all_sdrs).mean(), + "sdr_i": np.array(all_sdrs_i).mean(), + "si-snr": np.array(all_sisnrs).mean(), + "si-snr_i": np.array(all_sisnrs_i).mean(), + "pesq": np.array(all_pesqs).mean(), + "stoi": np.array(all_stois).mean(), + "csig": np.array(all_csigs).mean(), + "cbak": np.array(all_cbaks).mean(), + "covl": np.array(all_covls).mean(), + } + writer.writerow(row) + + logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean())) + logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean())) + logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean())) + logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean())) + logger.info("Mean PESQ {}".format(np.array(all_pesqs).mean())) + logger.info("Mean STOI {}".format(np.array(all_stois).mean())) + logger.info("Mean CSIG {}".format(np.array(all_csigs).mean())) + logger.info("Mean CBAK {}".format(np.array(all_cbaks).mean())) + logger.info("Mean COVL {}".format(np.array(all_covls).mean())) + + def save_audio(self, snt_id, noisy, clean, predictions): + "saves the test audio (noisy, clean, and estimated sources) on disk" + print("Saving enhanced sources (valid set)") + + # Create output folders + save_path = os.path.join( + self.hparams.save_folder, "valid_audio_results" + ) + save_path_enhanced = os.path.join(save_path, "enhanced_sources") + save_path_clean = os.path.join(save_path, "clean_sources") + save_path_noisy = os.path.join(save_path, "noisy_sources") + + for path in [save_path_enhanced, save_path_clean, save_path_noisy]: + if not os.path.exists(path): + os.makedirs(path) + + # Estimated source + signal = predictions[0, :] + signal = signal / signal.abs().max() + save_file = os.path.join( + save_path_enhanced, "item{}_sourcehat.wav".format(snt_id) + ) + torchaudio.save( + save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate + ) + + # Original source + signal = clean[0, :] + signal = signal / signal.abs().max() + save_file = os.path.join( + save_path_clean, "item{}_source.wav".format(snt_id) + ) + torchaudio.save( + save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate + ) + + # Noisy source + signal = noisy[0][0, :] + signal = signal / signal.abs().max() + save_file = os.path.join( + save_path_noisy, "item{}_noisy.wav".format(snt_id) + ) + torchaudio.save( + save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate + ) + + +def dataio_prep(hparams): + """Creates data processing pipeline""" + speech_dirs = [ + "read_speech", + "german_speech", + "french_speech", + "italian_speech", + "spanish_speech", + "russian_speech", + ] + audio_length = hparams["audio_length"] + + train_shard_patterns = [] + for dir in speech_dirs: + if not os.path.exists(os.path.join(hparams["train_data"], dir)): + dir = "" + shard_pattern = os.path.join(hparams["train_data"], dir, 
"shard-*.tar") + shard_files = glob.glob(shard_pattern) + train_shard_patterns.extend(shard_files) + + valid_shard_patterns = [] + for dir in speech_dirs: + if not os.path.exists(os.path.join(hparams["valid_data"], dir)): + dir = "" + shard_pattern = os.path.join(hparams["valid_data"], dir, "shard-*.tar") + shard_files = glob.glob(shard_pattern) + valid_shard_patterns.extend(shard_files) + + def meta_loader(split_path): + # Initialize the total number of samples + total_samples = 0 + + # Walk through the all subdirs + # eg. german_speech, read_speech, ... + for root, _, files in os.walk(split_path): + for file in files: + if file == "meta.json": + meta_json_path = os.path.join(root, file) + with open(meta_json_path, "rb") as f: + meta = json.load(f) + total_samples += meta.get("num_data_samples", 0) + + return total_samples + + def train_audio_pipeline(sample_dict: Dict, random_chunk=True): + key = sample_dict["__key__"] + clean_wav = sample_dict["clean_file"] + noise_wav = sample_dict["noise_file"] + noisy_wav = sample_dict["noisy_file"] + clean_sig = sample_dict["clean_audio.pth"].squeeze() + noise_sig = sample_dict["noise_audio.pth"].squeeze() + noisy_sig = sample_dict["noisy_audio.pth"].squeeze() + + return { + "id": key, + "clean_wav": clean_wav, + "clean_sig": clean_sig, + "noise_wav": noise_wav, + "noise_sig": noise_sig, + "noisy_wav": noisy_wav, + "noisy_sig": noisy_sig, + } + + def baseline_audio_pipeline(sample_dict: Dict, random_chunk=True): + key = sample_dict["__key__"] + noisy_sig = sample_dict["audio.pth"].squeeze() + + return { + "id": key, + "noisy_wav": key, + "noisy_sig": noisy_sig, + } + + def create_combined_dataset(shard_patterns, cache_dir): + # mix multiple datasets, where each dataset consists of multiple shards + # e.g. combine read_speech, german_speech etc. each with multiple shards. 
+ urls = [ + url + for shard in shard_patterns + for url in braceexpand.braceexpand(shard) + ] + + combined_dataset = ( + wds.WebDataset(urls, shardshuffle=True, cache_dir=cache_dir,) + .repeat() + .shuffle(1000) + .decode("pil") + .map(partial(train_audio_pipeline, random_chunk=True)) + ) + + return combined_dataset + + train_data = create_combined_dataset( + train_shard_patterns, hparams["shard_cache_dir"] + ) + train_samples = meta_loader(hparams["train_data"]) + logger.info(f"Training data- Number of samples: {train_samples}") + logger.info( + f"Training data - Total duration: {train_samples * audio_length/ 3600:.2f} hours" + ) + + valid_data = create_combined_dataset( + valid_shard_patterns, hparams["shard_cache_dir"] + ) + valid_samples = meta_loader(hparams["valid_data"]) + logger.info(f"Valid data- Number of samples: {valid_samples}") + logger.info( + f"Valid data - Total duration: {valid_samples * audio_length / 3600:.2f} hours" + ) + + baseline_data = ( + wds.WebDataset( + hparams["baseline_shards"], cache_dir=hparams["shard_cache_dir"], + ) + .repeat() + .shuffle(1000) + .decode("pil") + .map(partial(baseline_audio_pipeline, random_chunk=True)) + ) + + return train_data, valid_data, train_samples, valid_samples, baseline_data + + +if __name__ == "__main__": + # Load hyperparameters file with command-line overrides + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Initialize ddp (useful only for multi-GPU DDP training) + sb.utils.distributed.ddp_init_group(run_opts) + + # Logger info + logger = logging.getLogger(__name__) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + if hparams["use_tensorboard"]: + from speechbrain.utils.train_logger import TensorboardLogger + + hparams["tensorboard_train_logger"] = TensorboardLogger( + hparams["tensorboard_logs"] + ) + + ( + train_data, + valid_data, + num_train_samples, + num_valid_samples, + baseline_data, + ) = dataio_prep(hparams) + + # add collate_fn to dataloader options + hparams["dataloader_opts"]["collate_fn"] = PaddedBatch + hparams["dataloader_opts_valid"]["collate_fn"] = PaddedBatch + hparams["dataloader_opts_test"]["collate_fn"] = PaddedBatch + + hparams["dataloader_opts"]["looped_nominal_epoch"] = ( + num_train_samples // hparams["dataloader_opts"]["batch_size"] + ) + hparams["dataloader_opts_valid"]["looped_nominal_epoch"] = ( + num_valid_samples // hparams["dataloader_opts_valid"]["batch_size"] + ) + hparams["dataloader_opts_test"]["looped_nominal_epoch"] = ( + num_valid_samples // hparams["dataloader_opts_test"]["batch_size"] + ) + + # Load pretrained model if pretrained_enhancement is present in the yaml + if "pretrained_enhancement" in hparams: + run_on_main(hparams["pretrained_enhancement"].collect_files) + hparams["pretrained_enhancement"].load_collected() + + # Brain class initialization + enhancement = Enhancement( + modules=hparams["modules"], + opt_class=hparams["optimizer"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # re-initialize the parameters if we don't use a pretrained model + if "pretrained_enhancement" not in hparams: + for module in enhancement.modules.values(): + enhancement.reset_layer_recursively(module) + + # determine if frequency domain enhancement or not + use_freq_domain = hparams.get("use_freq_domain", False) + 
enhancement.use_freq_domain = use_freq_domain + + if not hparams["test_only"]: + # Training + enhancement.fit( + enhancement.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["dataloader_opts"], + valid_loader_kwargs=hparams["dataloader_opts_valid"], + ) + + ## Evaluation on valid data + # (since our test set is blind) + enhancement.evaluate( + valid_data, + max_key="pesq", + test_loader_kwargs=hparams["dataloader_opts_valid"], + ) + enhancement.save_results(valid_data) + + ## Save enhanced sources of baseline noisy testclips + def save_baseline_audio(snt_id, predictions): + "saves the estimated sources on disk" + # Create outout folder + save_path = os.path.join( + hparams["save_folder"], "baseline_audio_results" + ) + save_path_enhanced = os.path.join(save_path, "enhanced_testclips") + + if not os.path.exists(save_path_enhanced): + os.makedirs(save_path_enhanced) + + # Estimated source + signal = predictions[0, :] + signal = signal / signal.abs().max() + save_file = os.path.join(save_path_enhanced, snt_id) + ".wav" + + torchaudio.save( + save_file, signal.unsqueeze(0).cpu(), hparams["sample_rate"] + ) + + test_loader = sb.dataio.dataloader.make_dataloader( + baseline_data, **hparams["dataloader_opts_test"] + ) + + # Loop over all noisy baseline shards and save the enahanced clips + print("Saving enhanced sources (baseline set)") + with tqdm(test_loader, dynamic_ncols=True) as t: + for i, batch in enumerate(t): + # Apply Enhancement + snt_id = batch.id[0] + + with torch.no_grad(): + # Since only noisy sources are provided for baseline + # we use the compute_forward function with the same noisy + # signal for all inputs. (ugly hack) + predictions, clean = enhancement.compute_forward( + batch.noisy_sig, + batch.noisy_sig, + batch.noisy_sig, + sb.Stage.TEST, + ) + predictions = predictions[0] + + # Write enhanced wavs + save_baseline_audio(snt_id, predictions) diff --git a/recipes/DNS/noisyspeech_synthesizer/README.md b/recipes/DNS/noisyspeech_synthesizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b4c3e6870ff35054006a6b1d9a6b88dd15a44c19 --- /dev/null +++ b/recipes/DNS/noisyspeech_synthesizer/README.md @@ -0,0 +1,34 @@ +# **DNS: Noisy speech synthesizer** +This folder contains scripts to synthesize noisy audio for training. +Scripts have been taken from [official GitHub repo](https://github.com/microsoft/DNS-Challenge). + +Modify parameters like `sampling_rate`, `audio_length` , `total_hours` etc in the YAML file as per your requirement. + +## Synthesize clean-noisy data and create the Webdataset shards +Synthesize clean-noisy data and create WebDataset shards. 
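+
+Each synthesized sample is written to a shard under a unique key, with the
+clean, noise and noisy waveforms stored as float32 tensors
+(`clean_audio.pth`, `noise_audio.pth`, `noisy_audio.pth`) alongside the
+`clean_file`, `noise_file` and `noisy_file` identifiers. Below is a minimal
+sketch for inspecting one synthesized shard; the shard path is only an
+example, and decoding mirrors how the training recipe reads the `.pth`
+tensors.
+
+```
+import webdataset as wds
+
+# Hypothetical path to a single synthesized training shard
+shards = "synthesized_data_shards/train_shards/read_speech/shard-000000.tar"
+
+dataset = wds.WebDataset(shards).decode()
+for sample in dataset:
+    print(sample["__key__"])                # unique sample key
+    print(sample["clean_audio.pth"].shape)  # clean (reverberated) reference
+    print(sample["noisy_audio.pth"].shape)  # noisy mixture fed to the model
+    break
+```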
+
+### **Usage**
+To create the noisy dataset, run
+```
+## synthesize read speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name read_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize German speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name german_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Italian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name italian_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize French speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name french_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Spanish speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name spanish_speech --synthesized_data_dir synthesized_data_shards
+
+## synthesize Russian speech
+python noisyspeech_synthesizer_singleprocess.py noisyspeech_synthesizer.yaml --input_shards_dir ../DNS-shards --split_name russian_speech --synthesized_data_dir synthesized_data_shards
+```
+
+It's recommended to execute these commands in parallel for quicker synthesis.
+
+**Time**: Synthesizing a 500-hour dataset takes about 140 hours, which motivates the use of dynamic mixing.
diff --git a/recipes/DNS/noisyspeech_synthesizer/audiolib.py b/recipes/DNS/noisyspeech_synthesizer/audiolib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a1787923ce349a5972a34419831e6c4e9c24894
--- /dev/null
+++ b/recipes/DNS/noisyspeech_synthesizer/audiolib.py
@@ -0,0 +1,352 @@
+"""
+Source: https://github.com/microsoft/DNS-Challenge
+Ownership: Microsoft
+
+* Author
+    chkarada
+"""
+
+import os
+import numpy as np
+import soundfile as sf
+import subprocess
+import glob
+import librosa
+
+EPS = np.finfo(float).eps
+np.random.seed(0)
+
+
+def is_clipped(audio, clipping_threshold=0.99):
+    """Check if an audio signal is clipped.
+ """ + return any(abs(audio) > clipping_threshold) + + +def normalize(audio, target_level=-25): + """Normalize the signal to the target level""" + rms = (audio ** 2).mean() ** 0.5 + scalar = 10 ** (target_level / 20) / (rms + EPS) + audio = audio * scalar + return audio + + +def normalize_segmental_rms(audio, rms, target_level=-25): + """Normalize the signal to the target level + based on segmental RMS""" + scalar = 10 ** (target_level / 20) / (rms + EPS) + audio = audio * scalar + return audio + + +def audioread(path, norm=False, start=0, stop=None, target_level=-25): + """Function to read audio""" + + path = os.path.abspath(path) + if not os.path.exists(path): + raise ValueError("[{}] does not exist!".format(path)) + try: + audio, sample_rate = sf.read(path, start=start, stop=stop) + except RuntimeError: # fix for sph pcm-embedded shortened v2 + print("WARNING: Audio type not supported") + return (None, None) + + if len(audio.shape) == 1: # mono + if norm: + rms = (audio ** 2).mean() ** 0.5 + scalar = 10 ** (target_level / 20) / (rms + EPS) + audio = audio * scalar + else: # multi-channel + audio = audio.T + audio = audio.sum(axis=0) / audio.shape[0] + if norm: + audio = normalize(audio, target_level) + + return audio, sample_rate + + +def audiowrite( + destpath, + audio, + sample_rate=16000, + norm=False, + target_level=-25, + clipping_threshold=0.99, + clip_test=False, +): + """Function to write audio""" + + if clip_test: + if is_clipped(audio, clipping_threshold=clipping_threshold): + raise ValueError( + "Clipping detected in audiowrite()! " + + destpath + + " file not written to disk." + ) + + if norm: + audio = normalize(audio, target_level) + max_amp = max(abs(audio)) + if max_amp >= clipping_threshold: + audio = audio / max_amp * (clipping_threshold - EPS) + + destpath = os.path.abspath(destpath) + destdir = os.path.dirname(destpath) + + if not os.path.exists(destdir): + os.makedirs(destdir) + + sf.write(destpath, audio, sample_rate) + return + + +def add_reverb(sasxExe, input_wav, filter_file, output_wav): + """ Function to add reverb""" + command_sasx_apply_reverb = "{0} -r {1} \ + -f {2} -o {3}".format( + sasxExe, input_wav, filter_file, output_wav + ) + + subprocess.call(command_sasx_apply_reverb) + return output_wav + + +def add_clipping(audio, max_thresh_perc=0.8): + """Function to add clipping""" + threshold = max(abs(audio)) * max_thresh_perc + audioclipped = np.clip(audio, -threshold, threshold) + return audioclipped + + +def adsp_filter(Adspvqe, nearEndInput, nearEndOutput, farEndInput): + + command_adsp_clean = "{0} --breakOnErrors 0 --sampleRate 16000 --useEchoCancellation 0 \ + --operatingMode 2 --useDigitalAgcNearend 0 --useDigitalAgcFarend 0 \ + --useVirtualAGC 0 --useComfortNoiseGenerator 0 --useAnalogAutomaticGainControl 0 \ + --useNoiseReduction 0 --loopbackInputFile {1} --farEndInputFile {2} \ + --nearEndInputFile {3} --nearEndOutputFile {4}".format( + Adspvqe, farEndInput, farEndInput, nearEndInput, nearEndOutput + ) + subprocess.call(command_adsp_clean) + + +def snr_mixer( + params, clean, noise, snr, target_level=-25, clipping_threshold=0.99 +): + """Function to mix clean speech and noise at various SNR levels""" + # cfg = params['cfg'] + if len(clean) > len(noise): + noise = np.append(noise, np.zeros(len(clean) - len(noise))) + else: + clean = np.append(clean, np.zeros(len(noise) - len(clean))) + + # Normalizing to -25 dB FS + clean = clean / (max(abs(clean)) + EPS) + clean = normalize(clean, target_level) + rmsclean = (clean ** 2).mean() ** 0.5 + + noise = 
noise / (max(abs(noise)) + EPS) + noise = normalize(noise, target_level) + rmsnoise = (noise ** 2).mean() ** 0.5 + + # Set the noise level for a given SNR + noisescalar = rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS) + noisenewlevel = noise * noisescalar + + # Mix noise and clean speech + noisyspeech = clean + noisenewlevel + + # Randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value + # There is a chance of clipping that might happen with very less probability, which is not a major issue. + noisy_rms_level = np.random.randint( + params["target_level_lower"], params["target_level_upper"] + ) + rmsnoisy = (noisyspeech ** 2).mean() ** 0.5 + scalarnoisy = 10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS) + noisyspeech = noisyspeech * scalarnoisy + clean = clean * scalarnoisy + noisenewlevel = noisenewlevel * scalarnoisy + + # Final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly + if is_clipped(noisyspeech): + noisyspeech_maxamplevel = max(abs(noisyspeech)) / ( + clipping_threshold - EPS + ) + noisyspeech = noisyspeech / noisyspeech_maxamplevel + clean = clean / noisyspeech_maxamplevel + noisenewlevel = noisenewlevel / noisyspeech_maxamplevel + noisy_rms_level = int( + 20 + * np.log10(scalarnoisy / noisyspeech_maxamplevel * (rmsnoisy + EPS)) + ) + + return clean, noisenewlevel, noisyspeech, noisy_rms_level + + +def segmental_snr_mixer( + params, clean, noise, snr, target_level=-25, clipping_threshold=0.99 +): + """Function to mix clean speech and noise at various segmental SNR levels""" + # cfg = params['cfg'] + if len(clean) > len(noise): + noise = np.append(noise, np.zeros(len(clean) - len(noise))) + else: + clean = np.append(clean, np.zeros(len(noise) - len(clean))) + clean = clean / (max(abs(clean)) + EPS) + noise = noise / (max(abs(noise)) + EPS) + rmsclean, rmsnoise = active_rms(clean=clean, noise=noise) + clean = normalize_segmental_rms( + clean, rms=rmsclean, target_level=target_level + ) + noise = normalize_segmental_rms( + noise, rms=rmsnoise, target_level=target_level + ) + # Set the noise level for a given SNR + noisescalar = rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS) + noisenewlevel = noise * noisescalar + + # Mix noise and clean speech + noisyspeech = clean + noisenewlevel + # Randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value + # There is a chance of clipping that might happen with very less probability, which is not a major issue. + noisy_rms_level = np.random.randint( + params["target_level_lower"], params["target_level_upper"] + ) + rmsnoisy = (noisyspeech ** 2).mean() ** 0.5 + scalarnoisy = 10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS) + noisyspeech = noisyspeech * scalarnoisy + clean = clean * scalarnoisy + noisenewlevel = noisenewlevel * scalarnoisy + # Final check to see if there are any amplitudes exceeding +/- 1. 
If so, normalize all the signals accordingly + if is_clipped(noisyspeech): + noisyspeech_maxamplevel = max(abs(noisyspeech)) / ( + clipping_threshold - EPS + ) + noisyspeech = noisyspeech / noisyspeech_maxamplevel + clean = clean / noisyspeech_maxamplevel + noisenewlevel = noisenewlevel / noisyspeech_maxamplevel + noisy_rms_level = int( + 20 + * np.log10(scalarnoisy / noisyspeech_maxamplevel * (rmsnoisy + EPS)) + ) + + return clean, noisenewlevel, noisyspeech, noisy_rms_level + + +def active_rms(clean, noise, fs=16000, energy_thresh=-50): + """Returns the clean and noise RMS of the noise calculated only in the active portions""" + window_size = 100 # in ms + window_samples = int(fs * window_size / 1000) + sample_start = 0 + noise_active_segs = [] + clean_active_segs = [] + + while sample_start < len(noise): + sample_end = min(sample_start + window_samples, len(noise)) + noise_win = noise[sample_start:sample_end] + clean_win = clean[sample_start:sample_end] + noise_seg_rms = (noise_win ** 2).mean() ** 0.5 + # Considering frames with energy + if noise_seg_rms > energy_thresh: + noise_active_segs = np.append(noise_active_segs, noise_win) + clean_active_segs = np.append(clean_active_segs, clean_win) + sample_start += window_samples + + if len(noise_active_segs) != 0: + noise_rms = (noise_active_segs ** 2).mean() ** 0.5 + else: + noise_rms = EPS + + if len(clean_active_segs) != 0: + clean_rms = (clean_active_segs ** 2).mean() ** 0.5 + else: + clean_rms = EPS + + return clean_rms, noise_rms + + +def activitydetector(audio, fs=16000, energy_thresh=0.13, target_level=-25): + """Return the percentage of the time the audio signal is above an energy threshold""" + + audio = normalize(audio, target_level) + window_size = 50 # in ms + window_samples = int(fs * window_size / 1000) + sample_start = 0 + cnt = 0 + prev_energy_prob = 0 + active_frames = 0 + + a = -1 + b = 0.2 + alpha_rel = 0.05 + alpha_att = 0.8 + + while sample_start < len(audio): + sample_end = min(sample_start + window_samples, len(audio)) + audio_win = audio[sample_start:sample_end] + frame_rms = 20 * np.log10(sum(audio_win ** 2) + EPS) + frame_energy_prob = 1.0 / (1 + np.exp(-(a + b * frame_rms))) + + if frame_energy_prob > prev_energy_prob: + smoothed_energy_prob = ( + frame_energy_prob * alpha_att + + prev_energy_prob * (1 - alpha_att) + ) + else: + smoothed_energy_prob = ( + frame_energy_prob * alpha_rel + + prev_energy_prob * (1 - alpha_rel) + ) + + if smoothed_energy_prob > energy_thresh: + active_frames += 1 + prev_energy_prob = frame_energy_prob + sample_start += window_samples + cnt += 1 + + perc_active = active_frames / cnt + return perc_active + + +def resampler(input_dir, target_sr=16000, ext="*.wav"): + """Resamples the audio files in input_dir to target_sr""" + files = glob.glob(f"{input_dir}/" + ext) + for pathname in files: + print(pathname) + try: + audio, fs = audioread(pathname) + audio_resampled = librosa.core.resample(audio, fs, target_sr) + audiowrite(pathname, audio_resampled, target_sr) + except: # noqa + continue + + +def audio_segmenter(input_dir, dest_dir, segment_len=10, ext="*.wav"): + """Segments the audio clips in dir to segment_len in secs""" + files = glob.glob(f"{input_dir}/" + ext) + for i in range(len(files)): + audio, fs = audioread(files[i]) + + if ( + len(audio) > (segment_len * fs) + and len(audio) % (segment_len * fs) != 0 + ): + audio = np.append( + audio, + audio[0 : segment_len * fs - (len(audio) % (segment_len * fs))], + ) + if len(audio) < (segment_len * fs): + while len(audio) < 
(segment_len * fs): + audio = np.append(audio, audio) + audio = audio[: segment_len * fs] + + num_segments = int(len(audio) / (segment_len * fs)) + audio_segments = np.split(audio, num_segments) + + basefilename = os.path.basename(files[i]) + basename, ext = os.path.splitext(basefilename) + + for j in range(len(audio_segments)): + newname = basename + "_" + str(j) + ext + destpath = os.path.join(dest_dir, newname) + audiowrite(destpath, audio_segments[j], fs) diff --git a/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e3073702f310c314fbafd27c614e1e246a96c4 --- /dev/null +++ b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml @@ -0,0 +1,101 @@ +# yamllint disable +################################ +# Configuration for generating Noisy Speech Dataset +# - sampling_rate: Specify the sampling rate. Default is 16 kHz +# - audioformat: default is .wav +# - audio_length: Minimum Length of each audio clip (noisy and clean speech) +# in seconds that will be generated by augmenting utterances. +# - silence_length: Duration of silence introduced between clean speech +# utterances. +# - total_hours: Total number of hours of data required. Units are in hours. +# - snr_lower: Lower bound for SNR required (default: 0 dB) +# - snr_upper: Upper bound for SNR required (default: 40 dB) +# - target_level_lower: Lower bound for the target audio level +# before audiowrite (default: -35 dB) +# - target_level_upper: Upper bound for the target audio level +# before audiowrite (default: -15 dB) +# - total_snrlevels: Number of SNR levels required (default: 5, which means +# there are 5 levels between snr_lower and snr_upper) +# - clean_activity_threshold: Activity threshold for clean speech +# - noise_activity_threshold: Activity threshold for noise +# - fileindex_start: Starting file ID that will be used in filenames +# - fileindex_end: Last file ID that will be used in filenames +# - is_test_set: Set it to True if it is the test set, else False for the +# - log_dir: Specify path to the directory to store all the log files +# ################################ +# yamllint enable + + +# Data storage params +input_shards_dir: !PLACEHOLDER # ../DNS-shards +split_name: !PLACEHOLDER # read_speech, german_speech, italian_speech, french_speech etc +rirs: RIR_table_simple.csv + +# Noisy data synthesis params +sampling_rate: 16000 # sampling rate of synthesized signal +audioformat: "*.wav" +audio_length: 4 +silence_length: 0.2 +total_hours: 100 +snr_lower: -5 +snr_upper: 15 +randomize_snr: True +target_level_lower: -35 +target_level_upper: -15 +total_snrlevels: 21 +clean_activity_threshold: 0.6 +noise_activity_threshold: 0.0 +fileindex_start: None +fileindex_end: None +is_test_set: False + +# Source dir +rir_table_csv: !ref <rirs> + +# Directory path where Webdatasets of DNS clean and noise shards are located. +input_sampling_rate: 48000 # sampling rate of input signal +clean_meta: !ref <input_shards_dir>/clean_fullband/<split_name>/meta.json +noise_meta: !ref <input_shards_dir>/noise_fullband/meta.json +clean_fullband_shards: !ref <input_shards_dir>/clean_fullband/<split_name>/shard-{000000..999999}.tar +noise_fullband_shards: !ref <input_shards_dir>/noise_fullband/shard-{000000..999999}.tar + +# Configuration for synthesizing shards of clean-noisy pairs. +samples_per_shard: 5000 + +# Destination directory for storing shards of synthesized data. 
+synthesized_data_dir: !PLACEHOLDER # synthesized_data_shards +train_shard_destination: !ref <synthesized_data_dir>/train_shards/<split_name> +valid_shard_destination: !ref <synthesized_data_dir>/valid_shards/<split_name> + +# Set to a directory on a large disk if using Webdataset shards hosted on the web. +shard_cache_dir: + +# These can be skipped. (uncomment if you want to use them) +# clean_singing: !PLACEHOLDER # ../DNS-shards/clean_fullband/VocalSet_48kHz_mono/ +# clean_emotion: !PLACEHOLDER # ../DNS-shards/clean_fullband/emotional_speech/ +## Aishell data needs to be downloaded separately. +# clean_mandarin: !PLACEHOLDER # ../DNS-shards/clean_fullband/mandrin_speech/data_aishell + +log_dir: !ref <split_name>_logs +noise_types_excluded: None + +## Config: add singing voice to clean speech +use_singing_data: 0 # 0 for no, 1 for yes +# 1 for only male, 2 for only female, 3 (default) for both male and female +singing_choice: 3 + +## Config: add emotional data to clean speech +# 0 for no, 1 for yes +use_emotion_data: 0 + +## Config: add Chinese (mandarin) data to clean speech +# 0 for no, 1 for yes +use_mandarin_data: 0 + +## Config: add reverb to clean speech +# 1 for only real rir, 2 for only synthetic rir, 3 (default) use both real and synthetic +rir_choice: 3 +# lower bound of t60 range in seconds +lower_t60: 0.3 +# upper bound of t60 range in seconds +upper_t60: 1.3 diff --git a/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..478c786d054a5e3e5f832476802c4e27a11c2ab1 --- /dev/null +++ b/recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer_singleprocess.py @@ -0,0 +1,720 @@ +""" +Source: https://github.com/microsoft/DNS-Challenge +Ownership: Microsoft + +This script will attempt to use each clean and noise +webdataset shards to synthesize clean-noisy pairs of +audio. The output is again stored in webdataset shards. 
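+Each synthesized sample stores the clean, noise and noisy waveforms as
+float32 tensors, and the output shards are split into separate train and
+valid sets.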
+ +* Author + chkarada + +* Further modified + Sangeet Sagar (2023) +""" + +# Note: This single process audio synthesizer will attempt to use each clean +# speech sourcefile once (from the webdataset shards), as it does not +# randomly sample from these files + +import sys +import os +from pathlib import Path +import random +import time + +import numpy as np +from scipy import signal +from scipy.io import wavfile + +import librosa + +import utils +from audiolib import ( + segmental_snr_mixer, + activitydetector, + is_clipped, +) + +import pandas as pd +import json +from functools import partial +from typing import Dict +from collections import defaultdict + + +import speechbrain as sb +import webdataset as wds +from hyperpyyaml import load_hyperpyyaml +import torch + +np.random.seed(5) +random.seed(5) + +MAXTRIES = 50 +MAXFILELEN = 100 + +start = time.time() + + +def add_pyreverb(clean_speech, rir): + """ + Add reverb to cean signal + """ + reverb_speech = signal.fftconvolve(clean_speech, rir, mode="full") + + # make reverb_speech same length as clean_speech + reverb_speech = reverb_speech[0 : clean_speech.shape[0]] + + return reverb_speech + + +def build_audio(is_clean, params, index, audio_samples_length=-1): + """Construct an audio signal from source files""" + + fs_output = params["fs"] + silence_length = params["silence_length"] + if audio_samples_length == -1: + audio_samples_length = int(params["audio_length"] * params["fs"]) + + output_audio = np.zeros(0) + remaining_length = audio_samples_length + files_used = [] + clipped_files = [] + + if is_clean: + data_iterator = iter(params["clean_data"]) + idx = index + else: + data_iterator = iter(params["noise_data"]) + idx = index + + # initialize silence + silence = np.zeros(int(fs_output * silence_length)) + + # iterate through multiple clips until we have a long enough signal + tries_left = MAXTRIES + while remaining_length > 0 and tries_left > 0: + # read next audio file and resample if necessary + fs_input = params["fs_input"] + batch = next(data_iterator) + input_audio = batch["sig"].numpy() + + if input_audio is None: + sys.stderr.write( + "\nWARNING: Cannot read file: %s\n" % batch["__key__"] + ) + continue + if fs_input != fs_output: + input_audio = librosa.resample( + input_audio, orig_sr=fs_input, target_sr=fs_output + ) + + # if current file is longer than remaining desired length, and this is + # noise generation or this is training set, subsample it randomly + if len(input_audio) > remaining_length and ( + not is_clean or not params["is_test_set"] + ): + idx_seg = np.random.randint(0, len(input_audio) - remaining_length) + input_audio = input_audio[idx_seg : idx_seg + remaining_length] + + # check for clipping, and if found move onto next file + if is_clipped(input_audio): + clipped_files.append(batch["__key__"]) + tries_left -= 1 + continue + + # concatenate current input audio to output audio stream + files_used.append(batch["__key__"]) + output_audio = np.append(output_audio, input_audio) + remaining_length -= len(input_audio) + + # add some silence if we have not reached desired audio length + if remaining_length > 0: + silence_len = min(remaining_length, len(silence)) + output_audio = np.append(output_audio, silence[:silence_len]) + remaining_length -= silence_len + + if tries_left == 0 and not is_clean and "noise_data" in params.keys(): + print( + "There are not enough non-clipped files in the " + + "given noise directory to complete the audio build" + ) + return [], [], clipped_files, idx + + return output_audio, 
files_used, clipped_files, idx + + +def gen_audio(is_clean, params, index, audio_samples_length=-1): + """Calls build_audio() to get an audio signal, and verify that it meets the + activity threshold""" + + clipped_files = [] + low_activity_files = [] + if audio_samples_length == -1: + audio_samples_length = int(params["audio_length"] * params["fs"]) + if is_clean: + activity_threshold = params["clean_activity_threshold"] + else: + activity_threshold = params["noise_activity_threshold"] + + while True: + audio, source_files, new_clipped_files, index = build_audio( + is_clean, params, index, audio_samples_length + ) + + clipped_files += new_clipped_files + if len(audio) < audio_samples_length: + continue + + if activity_threshold == 0.0: + break + + percactive = activitydetector(audio=audio) + if percactive > activity_threshold: + break + else: + low_activity_files += source_files + + return audio, source_files, clipped_files, low_activity_files, index + + +def main_gen(params): + """Calls gen_audio() to generate the audio signals, verifies that they meet + the requirements, and writes the files to storage""" + + clean_source_files = [] + clean_clipped_files = [] + clean_low_activity_files = [] + noise_source_files = [] + noise_clipped_files = [] + noise_low_activity_files = [] + + clean_index = 0 + noise_index = 0 + + # write shards + train_shards_path = Path(params["train_shard_destination"]) + train_shards_path.mkdir(exist_ok=True, parents=True) + valid_shards_path = Path(params["valid_shard_destination"]) + valid_shards_path.mkdir(exist_ok=True, parents=True) + + all_keys = set() + train_pattern = str(train_shards_path / "shard") + "-%06d.tar" + valid_pattern = str(valid_shards_path / "shard") + "-%06d.tar" + samples_per_shard = params["samples_per_shard"] + + # track statistics on data + train_sample_keys = defaultdict(list) + valid_sample_keys = defaultdict(list) + + # Define the percentage of data to be used for validation + validation_percentage = 0.05 + + # Calculate the number of samples for training and validation + total_samples = params["fileindex_end"] - params["fileindex_start"] + 1 + num_validation_samples = int(total_samples * validation_percentage) + + # Define separate ShardWriters for training and validation data + train_writer = wds.ShardWriter(train_pattern, maxcount=samples_per_shard) + valid_writer = wds.ShardWriter(valid_pattern, maxcount=samples_per_shard) + + # Initialize counters and data lists for statistics + file_num = params["fileindex_start"] + train_data_tuples = [] + valid_data_tuples = [] + + while file_num <= params["fileindex_end"]: + print( + "\rFiles synthesized {:4d}/{:4d}".format( + file_num, params["fileindex_end"] + ), + end="", + ) + # CLEAN SPEECH GENERATION + clean, clean_sf, clean_cf, clean_laf, clean_index = gen_audio( + True, params, clean_index + ) + + # add reverb with selected RIR + rir_index = random.randint(0, len(params["myrir"]) - 1) + + my_rir = os.path.normpath(os.path.join(params["myrir"][rir_index])) + (fs_rir, samples_rir) = wavfile.read(my_rir) + + my_channel = int(params["mychannel"][rir_index]) + + if samples_rir.ndim == 1: + samples_rir_ch = np.array(samples_rir) + + elif my_channel > 1: + samples_rir_ch = samples_rir[:, my_channel - 1] + else: + samples_rir_ch = samples_rir[:, my_channel - 1] + # print(samples_rir.shape) + # print(my_channel) + + # REVERB MIXED TO THE CLEAN SPEECH + clean = add_pyreverb(clean, samples_rir_ch) + + # generate noise + noise, noise_sf, noise_cf, noise_laf, noise_index = gen_audio( + False, 
params, noise_index, len(clean) + ) + + clean_clipped_files += clean_cf + clean_low_activity_files += clean_laf + noise_clipped_files += noise_cf + noise_low_activity_files += noise_laf + + # mix clean speech and noise + # if specified, use specified SNR value + if not params["randomize_snr"]: + snr = params["snr"] + # use a randomly sampled SNR value between the specified bounds + else: + snr = np.random.randint(params["snr_lower"], params["snr_upper"]) + + # NOISE ADDED TO THE REVERBED SPEECH + clean_snr, noise_snr, noisy_snr, target_level = segmental_snr_mixer( + params=params, clean=clean, noise=noise, snr=snr + ) + # Uncomment the below lines if you need segmental SNR and comment the above lines using snr_mixer + # clean_snr, noise_snr, noisy_snr, target_level = segmental_snr_mixer(params=params, + # clean=clean, + # noise=noise, + # snr=snr) + # unexpected clipping + if ( + is_clipped(clean_snr) + or is_clipped(noise_snr) + or is_clipped(noisy_snr) + ): + print( + "\nWarning: File #" + + str(file_num) + + " has unexpected clipping, " + + "returning without writing audio to disk" + ) + continue + + clean_source_files += clean_sf + noise_source_files += noise_sf + + # write resultant audio streams to files + hyphen = "-" + clean_source_filenamesonly = [ + i[:-4].split(os.path.sep)[-1] for i in clean_sf + ] + clean_files_joined = hyphen.join(clean_source_filenamesonly)[ + :MAXFILELEN + ] + noise_source_filenamesonly = [ + i[:-4].split(os.path.sep)[-1] for i in noise_sf + ] + noise_files_joined = hyphen.join(noise_source_filenamesonly)[ + :MAXFILELEN + ] + + noisyfilename = ( + clean_files_joined + + "_" + + noise_files_joined + + "_snr" + + str(snr) + + "_tl" + + str(target_level) + + "_fileid_" + + str(file_num) + ) + + # Period is not allowed in a WebDataset key name + cleanfilename = "clean_fileid_" + str(file_num) + cleanfilename = cleanfilename.replace(".", "_") + noisefilename = "noise_fileid_" + str(file_num) + noisefilename = noisefilename.replace(".", "_") + + file_num += 1 + + # store statistics + key = noisyfilename + key = key.replace(".", "_") + lang = params["split_name"].split("_")[0] + t = (key, lang) + + # verify key is unique + assert cleanfilename not in all_keys + all_keys.add(cleanfilename) + + # Split the data between training and validation based on the file number + if file_num % total_samples <= num_validation_samples: + # Write to validation set + valid_sample_keys[lang].append(key) + valid_data_tuples.append(t) + sample = { + "__key__": key, + "noisy_file": key, + "clean_file": cleanfilename, + "noise_file": noisefilename, + "clean_audio.pth": torch.tensor(clean_snr).to(torch.float32), + "noise_audio.pth": torch.tensor(noise_snr).to(torch.float32), + "noisy_audio.pth": torch.tensor(noisy_snr).to(torch.float32), + } + valid_writer.write(sample) + else: + # Write to training set + train_sample_keys[lang].append(key) + train_data_tuples.append(t) + sample = { + "__key__": key, + "noisy_file": key, + "clean_file": cleanfilename, + "noise_file": noisefilename, + "clean_audio.pth": torch.tensor(clean_snr).to(torch.float32), + "noise_audio.pth": torch.tensor(noise_snr).to(torch.float32), + "noisy_audio.pth": torch.tensor(noisy_snr).to(torch.float32), + } + train_writer.write(sample) + + train_writer.close() + valid_writer.close() + + # Write meta.json files for both training and validation + train_meta_dict = { + "language_id": lang, + "sample_keys_per_language": train_sample_keys, + "num_data_samples": len(train_data_tuples), + } + valid_meta_dict = { + 
"language_id": lang, + "sample_keys_per_language": valid_sample_keys, + "num_data_samples": len(valid_data_tuples), + } + + with (train_shards_path / "meta.json").open("w") as f: + json.dump(train_meta_dict, f, indent=4) + + with (valid_shards_path / "meta.json").open("w") as f: + json.dump(valid_meta_dict, f, indent=4) + + return ( + clean_source_files, + clean_clipped_files, + clean_low_activity_files, + noise_source_files, + noise_clipped_files, + noise_low_activity_files, + ) + + +def main_body(): # noqa + """Main body of this file""" + + params = dict() + + # Load hyperparameters file with command-line overrides + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Data Directories and Settings + params["split_name"] = hparams["split_name"] + + # Audio Settings + params["fs"] = int(hparams["sampling_rate"]) + params["fs_input"] = int( + hparams["input_sampling_rate"] + ) # Sampling rate of input data + params["audioformat"] = hparams["audioformat"] + params["audio_length"] = float(hparams["audio_length"]) + params["silence_length"] = float(hparams["silence_length"]) + params["total_hours"] = float(hparams["total_hours"]) + + # Clean Data Categories + params["use_singing_data"] = int(hparams["use_singing_data"]) + if hasattr(hparams, "clean_singing"): + params["clean_singing"] = str(hparams["clean_singing"]) + params["singing_choice"] = int(hparams["singing_choice"]) + + params["use_emotion_data"] = int(hparams["use_emotion_data"]) + if hasattr(hparams, "clean_emotion"): + params["clean_emotion"] = str(hparams["clean_emotion"]) + + params["use_mandarin_data"] = int(hparams["use_mandarin_data"]) + if hasattr(hparams, "clean_mandarin"): + params["clean_mandarin"] = str(hparams["clean_mandarin"]) + + # Room Impulse Response (RIR) Settings + params["rir_choice"] = int(hparams["rir_choice"]) + params["lower_t60"] = float(hparams["lower_t60"]) + params["upper_t60"] = float(hparams["upper_t60"]) + params["rir_table_csv"] = str(hparams["rir_table_csv"]) + + # File Indexing + if ( + hparams["fileindex_start"] != "None" + and hparams["fileindex_end"] != "None" + ): + params["num_files"] = int(hparams["fileindex_end"]) - int( + params["fileindex_start"] + ) + params["fileindex_start"] = int(hparams["fileindex_start"]) + params["fileindex_end"] = int(hparams["fileindex_end"]) + else: + params["num_files"] = int( + (params["total_hours"] * 60 * 60) / params["audio_length"] + ) + params["fileindex_start"] = 0 + params["fileindex_end"] = params["num_files"] + + print("Number of files to be synthesized:", params["num_files"]) + + # Data Generation and Synthesis Settings + params["is_test_set"] = utils.str2bool(str(hparams["is_test_set"])) + params["clean_activity_threshold"] = float( + hparams["clean_activity_threshold"] + ) + params["noise_activity_threshold"] = float( + hparams["noise_activity_threshold"] + ) + params["snr_lower"] = int(hparams["snr_lower"]) + params["snr_upper"] = int(hparams["snr_upper"]) + params["randomize_snr"] = utils.str2bool(str(hparams["randomize_snr"])) + params["target_level_lower"] = int(hparams["target_level_lower"]) + params["target_level_upper"] = int(hparams["target_level_upper"]) + + if hasattr(hparams, "snr"): + params["snr"] = int(hparams["snr"]) + else: + params["snr"] = int((params["snr_lower"] + params["snr_upper"]) / 2) + + # Synthesized Data Destination + params["samples_per_shard"] = hparams["samples_per_shard"] + params["train_shard_destination"] = 
hparams["train_shard_destination"] + params["valid_shard_destination"] = hparams["valid_shard_destination"] + + #### Shard data extraction ~~~ + # load the meta info json file + + with wds.gopen(hparams["clean_meta"], "rb") as f: + clean_meta = json.load(f) + with wds.gopen(hparams["noise_meta"], "rb") as f: + noise_meta = json.load(f) + + def audio_pipeline(sample_dict: Dict, random_chunk=True): + key = sample_dict["__key__"] + audio_tensor = sample_dict["audio.pth"] + + sig = audio_tensor.squeeze() + + return { + "sig": sig, + "id": key, + } + + clean_data = ( + wds.WebDataset( + hparams["clean_fullband_shards"], + cache_dir=hparams["shard_cache_dir"], + ) + .repeat() + .shuffle(1000) + .decode("pil") + .map(partial(audio_pipeline, random_chunk=True)) + ) + print(f"Clean data consist of {clean_meta['num_data_samples']} samples") + + noise_data = ( + wds.WebDataset( + hparams["noise_fullband_shards"], + cache_dir=hparams["shard_cache_dir"], + ) + .repeat() + .shuffle(1000) + .decode("pil") + .map(partial(audio_pipeline, random_chunk=True)) + ) + print(f"Noise data consist of {noise_meta['num_data_samples']} samples") + + params["clean_data"] = clean_data + params["noise_data"] = noise_data + + # add singing voice to clean speech + if params["use_singing_data"] == 1: + raise NotImplementedError("Add sining voice to clean speech") + else: + print("NOT using singing data for training!") + + # add emotion data to clean speech + if params["use_emotion_data"] == 1: + raise NotImplementedError("Add emotional data to clean speech") + else: + print("NOT using emotion data for training!") + + # add mandarin data to clean speech + if params["use_mandarin_data"] == 1: + raise NotImplementedError("Add Mandarin data to clean speech") + else: + print("NOT using non-english (Mandarin) data for training!") + + # rir + temp = pd.read_csv( + params["rir_table_csv"], + skiprows=[1], + sep=",", + header=None, + names=["wavfile", "channel", "T60_WB", "C50_WB", "isRealRIR"], + ) + temp.keys() + # temp.wavfile + + rir_wav = temp["wavfile"][1:] # 115413 + rir_channel = temp["channel"][1:] + rir_t60 = temp["T60_WB"][1:] + rir_isreal = temp["isRealRIR"][1:] + + rir_wav2 = [w.replace("\\", "/") for w in rir_wav] + rir_channel2 = [w for w in rir_channel] + rir_t60_2 = [w for w in rir_t60] + rir_isreal2 = [w for w in rir_isreal] + + myrir = [] + mychannel = [] + myt60 = [] + + lower_t60 = params["lower_t60"] + upper_t60 = params["upper_t60"] + + if params["rir_choice"] == 1: # real 3076 IRs + real_indices = [i for i, x in enumerate(rir_isreal2) if x == "1"] + + chosen_i = [] + for i in real_indices: + if (float(rir_t60_2[i]) >= lower_t60) and ( + float(rir_t60_2[i]) <= upper_t60 + ): + chosen_i.append(i) + + myrir = [rir_wav2[i] for i in chosen_i] + mychannel = [rir_channel2[i] for i in chosen_i] + myt60 = [rir_t60_2[i] for i in chosen_i] + + elif params["rir_choice"] == 2: # synthetic 112337 IRs + synthetic_indices = [i for i, x in enumerate(rir_isreal2) if x == "0"] + + chosen_i = [] + for i in synthetic_indices: + if (float(rir_t60_2[i]) >= lower_t60) and ( + float(rir_t60_2[i]) <= upper_t60 + ): + chosen_i.append(i) + + myrir = [rir_wav2[i] for i in chosen_i] + mychannel = [rir_channel2[i] for i in chosen_i] + myt60 = [rir_t60_2[i] for i in chosen_i] + + elif params["rir_choice"] == 3: # both real and synthetic + all_indices = [i for i, x in enumerate(rir_isreal2)] + + chosen_i = [] + for i in all_indices: + if (float(rir_t60_2[i]) >= lower_t60) and ( + float(rir_t60_2[i]) <= upper_t60 + ): + chosen_i.append(i) + 
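+        # Select RIR paths, channels and T60s for the retained indices.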
+ myrir = [rir_wav2[i] for i in chosen_i] + mychannel = [rir_channel2[i] for i in chosen_i] + myt60 = [rir_t60_2[i] for i in chosen_i] + + else: # default both real and synthetic + all_indices = [i for i, x in enumerate(rir_isreal2)] + + chosen_i = [] + for i in all_indices: + if (float(rir_t60_2[i]) >= lower_t60) and ( + float(rir_t60_2[i]) <= upper_t60 + ): + chosen_i.append(i) + + myrir = [rir_wav2[i] for i in chosen_i] + mychannel = [rir_channel2[i] for i in chosen_i] + myt60 = [rir_t60_2[i] for i in chosen_i] + + params["myrir"] = myrir + params["mychannel"] = mychannel + params["myt60"] = myt60 + + # Call main_gen() to generate audio + ( + clean_source_files, + clean_clipped_files, + clean_low_activity_files, + noise_source_files, + noise_clipped_files, + noise_low_activity_files, + ) = main_gen(params) + + # Create log directory if needed, and write log files of clipped and low activity files + log_dir = utils.get_dir(hparams, "log_dir", "Logs") + + utils.write_log_file( + log_dir, "source_files.csv", clean_source_files + noise_source_files + ) + utils.write_log_file( + log_dir, "clipped_files.csv", clean_clipped_files + noise_clipped_files + ) + utils.write_log_file( + log_dir, + "low_activity_files.csv", + clean_low_activity_files + noise_low_activity_files, + ) + + # Compute and print stats about percentange of clipped and low activity files + total_clean = ( + len(clean_source_files) + + len(clean_clipped_files) + + len(clean_low_activity_files) + ) + total_noise = ( + len(noise_source_files) + + len(noise_clipped_files) + + len(noise_low_activity_files) + ) + pct_clean_clipped = round(len(clean_clipped_files) / total_clean * 100, 1) + pct_noise_clipped = round(len(noise_clipped_files) / total_noise * 100, 1) + pct_clean_low_activity = round( + len(clean_low_activity_files) / total_clean * 100, 1 + ) + pct_noise_low_activity = round( + len(noise_low_activity_files) / total_noise * 100, 1 + ) + + print( + "\nOf the " + + str(total_clean) + + " clean speech files analyzed, " + + str(pct_clean_clipped) + + "% had clipping, and " + + str(pct_clean_low_activity) + + "% had low activity " + + "(below " + + str(params["clean_activity_threshold"] * 100) + + "% active percentage)" + ) + print( + "Of the " + + str(total_noise) + + " noise files analyzed, " + + str(pct_noise_clipped) + + "% had clipping, and " + + str(pct_noise_low_activity) + + "% had low activity " + + "(below " + + str(params["noise_activity_threshold"] * 100) + + "% active percentage)" + ) + + +if __name__ == "__main__": + main_body() diff --git a/recipes/DNS/noisyspeech_synthesizer/utils.py b/recipes/DNS/noisyspeech_synthesizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3826a906f347b8a5e9f1c2cb91e2d288d7fe26 --- /dev/null +++ b/recipes/DNS/noisyspeech_synthesizer/utils.py @@ -0,0 +1,54 @@ +""" +Source: https://github.com/microsoft/DNS-Challenge +Ownership: Microsoft + +* Author + rocheng +""" +import os +import csv +from shutil import copyfile +import glob + + +def get_dir(cfg, param_name, new_dir_name): + """Helper function to retrieve directory name if it exists, + create it if it doesn't exist""" + + if param_name in cfg: + dir_name = cfg[param_name] + else: + dir_name = os.path.join(os.path.dirname(__file__), new_dir_name) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + return dir_name + + +def write_log_file(log_dir, log_filename, data): + """Helper function to write log file""" + # data = zip(*data) + with open( + os.path.join(log_dir, log_filename), mode="w", 
newline="" + ) as csvfile: + csvwriter = csv.writer( + csvfile, delimiter=" ", quotechar="|", quoting=csv.QUOTE_MINIMAL + ) + for row in data: + csvwriter.writerow([row]) + + +def str2bool(string): + """Convert a string to a boolean value. + """ + return string.lower() in ("yes", "true", "t", "1") + + +def rename_copyfile(src_path, dest_dir, prefix="", ext="*.wav"): + """Copy and rename files from a source directory to a destination directory. + """ + srcfiles = glob.glob(f"{src_path}/" + ext) + for i in range(len(srcfiles)): + dest_path = os.path.join( + dest_dir, prefix + "_" + os.path.basename(srcfiles[i]) + ) + copyfile(srcfiles[i], dest_path) diff --git a/tests/consistency/test_recipe.py b/tests/consistency/test_recipe.py index d3ef494af30a09fb3e2b3357a3553ac8b59ccdf7..802efb1094fb001e89cdeb59af527973465552dd 100644 --- a/tests/consistency/test_recipe.py +++ b/tests/consistency/test_recipe.py @@ -23,6 +23,7 @@ def test_recipe_list( "recipes/Voicebank/MTL/CoopNet/hparams/logger.yaml", "recipes/LibriParty/generate_dataset/dataset.yaml", "hpopt.yaml", + "recipes/DNS/noisyspeech_synthesizer/noisyspeech_synthesizer.yaml", ], ): """This test checks if all the all hparam file of all the recipes are listed diff --git a/tests/recipes/DNS.csv b/tests/recipes/DNS.csv new file mode 100644 index 0000000000000000000000000000000000000000..05f26a060db9ecde83ac87c6d222679c46bec641 --- /dev/null +++ b/tests/recipes/DNS.csv @@ -0,0 +1,2 @@ +Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks,test_download +Enhancement,DNS,recipes/DNS/enhancement/train.py,recipes/DNS/enhancement/hparams/sepformer-dns-16k.yaml,recipes/DNS/create_wds_shards.py,recipes/DNS/enhancement/README.md,,https://huggingface.co/speechbrain/sepformer-dns4-16k-enhancement,--data_folder=tests/download/DNS/ --train_data=tests/download/DNS/train_shards/ --valid_data=tests/download/DNS/train_shards/ --baseline_noisy_shards_folder=tests/download/DNS/baseline/ --baseline_shards=tests/download/DNS/baseline/shard-000000.tar --N_epochs=1 --batch_size=1,"file_exists=[valid_results.csv,train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]","download_file('https://www.dropbox.com/scl/fi/i3iwzmrnyw8pgputkqvgq/DNS.zip?rlkey=1ka0g2ig4x488fg1exnxmbprd&dl=1', 'tests/download/DNS.zip', unpack=True, dest_unpack='tests/download/', write_permissions=True)"